-rw-r--r--  train_lora.py    | 159
-rw-r--r--  training/lora.py |  68
2 files changed, 89 insertions, 138 deletions
diff --git a/train_lora.py b/train_lora.py
index ffc1d10..34e1008 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -24,8 +23,9 @@ from slugify import slugify
 from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
-from training.lora import LoraAttention
+from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor

 logger = get_logger(__name__)
@@ -109,7 +109,7 @@ def parse_args():
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/dreambooth",
+        default="output/lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
@@ -176,7 +176,7 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -348,76 +348,45 @@ def parse_args():
     return args


-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
         accelerator,
         vae,
         unet,
-        unet_lora,
         tokenizer,
         text_encoder,
+        unet_lora,
         scheduler,
-        output_dir: Path,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
+        output_dir: Path,
         sample_image_size,
         sample_batches,
         sample_batch_size,
         seed
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
-        self.unet_lora = unet_lora
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.unet_lora = unet_lora
         self.scheduler = scheduler
-        self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size

     @torch.no_grad()
     def save_model(self):
@@ -433,83 +402,18 @@ class Checkpointer:

     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
+        # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
+            text_encoder=self.text_encoder,
             vae=self.vae,
-            unet=unet,
+            unet=self.unet,
             tokenizer=self.tokenizer,
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)

-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)

-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
-
-        del unet
-        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -558,7 +462,11 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')

-    unet_lora = LoraAttention()
+    unet_lora = LoraAttnProcessor(
+        cross_attention_dim=unet.cross_attention_dim,
+        inner_dim=unet.in_channels,
+        r=4,
+    )

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -618,8 +526,8 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)

     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )

@@ -639,7 +547,7 @@ def main():
         [
             {
                 'params': unet_lora.parameters(),
-                'lr': args.learning_rate_unet,
+                'lr': args.learning_rate,
             },
         ],
         betas=(args.adam_beta1, args.adam_beta2),
@@ -801,7 +709,7 @@ def main():
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
         config["placeholder_token"] = " ".join(config["placeholder_token"])
-        accelerator.init_trackers("dreambooth", config=config)
+        accelerator.init_trackers("lora", config=config)

     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -832,6 +740,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
+        unet_lora=unet_lora,
         output_dir=basepath,
         instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
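
Checkpointer.save_model() falls outside the changed hunks above and is not shown. Because LoraAttnProcessor inherits from ModelMixin and ConfigMixin (see training/lora.py below), the LoRA weights can be serialized on their own; a minimal sketch, assuming the standard diffusers save_pretrained()/from_pretrained() API and an illustrative "lora" subdirectory name that is not taken from this commit:

    # Hypothetical sketch, not part of this diff: persist only the LoRA weights.
    # ModelMixin provides save_pretrained()/from_pretrained().
    lora = accelerator.unwrap_model(unet_lora)
    lora.save_pretrained(basepath.joinpath("lora"))  # directory name is illustrative

    # The weights could later be restored the same way:
    # unet_lora = LoraAttnProcessor.from_pretrained(basepath.joinpath("lora"))
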
diff --git a/training/lora.py b/training/lora.py
index d8dc147..e1c0971 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,27 +1,69 @@
 import torch.nn as nn
-from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config

+from diffusers import ModelMixin, ConfigMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.cross_attention import CrossAttention
+from diffusers.utils.import_utils import is_xformers_available

-class LoraAttention(ModelMixin, ConfigMixin):
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+class LoraAttnProcessor(ModelMixin, ConfigMixin):
     @register_to_config
-    def __init__(self, in_features, out_features, r=4):
+    def __init__(
+        self,
+        cross_attention_dim,
+        inner_dim,
+        r: int = 4
+    ):
         super().__init__()

-        if r > min(in_features, out_features):
+        if r > min(cross_attention_dim, inner_dim):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
             )

-        self.lora_down = nn.Linear(in_features, r, bias=False)
-        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+
+        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+
         self.scale = 1.0

-        self.processor = XFormersCrossAttnProcessor()
+        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_k_up.weight)
+
+        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_v_up.weight)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
+        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)

-        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_up.weight)
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)

-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
-        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
-        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
         return hidden_states
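
The diff does not show how the processor is attached to the UNet. LoraAttnProcessor follows the diffusers attention-processor calling convention (__call__(attn, hidden_states, encoder_hidden_states, attention_mask)), so a rough wiring sketch, assuming a diffusers version that exposes UNet2DConditionModel.set_attn_processor and mirroring the constructor call from train_lora.py above, might look like this:

    # Hypothetical wiring, not taken from this commit.
    import torch
    from diffusers import UNet2DConditionModel
    from training.lora import LoraAttnProcessor

    # "runwayml/stable-diffusion-v1-5" is only an illustrative checkpoint.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet")

    unet_lora = LoraAttnProcessor(
        cross_attention_dim=unet.cross_attention_dim,
        inner_dim=unet.in_channels,
        r=4,
    )

    unet.set_attn_processor(unet_lora)  # route every CrossAttention block through the LoRA processor
    unet.requires_grad_(False)          # the base UNet stays frozen
    optimizer = torch.optim.AdamW(unet_lora.parameters(), lr=2e-6)  # only the LoRA weights are trained
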