 train_lora.py    | 159
 training/lora.py |  68
 2 files changed, 89 insertions(+), 138 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index ffc1d10..34e1008 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -24,8 +23,9 @@ from slugify import slugify
 from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
-from training.lora import LoraAttention
+from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -109,7 +109,7 @@ def parse_args():
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/dreambooth",
+        default="output/lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
@@ -176,7 +176,7 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -348,76 +348,45 @@ def parse_args():
     return args
 
 
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
         accelerator,
         vae,
         unet,
-        unet_lora,
         tokenizer,
         text_encoder,
+        unet_lora,
         scheduler,
-        output_dir: Path,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
+        output_dir: Path,
         sample_image_size,
         sample_batches,
         sample_batch_size,
         seed
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
-        self.unet_lora = unet_lora
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.unet_lora = unet_lora
         self.scheduler = scheduler
-        self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def save_model(self):
@@ -433,83 +402,18 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
+        # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
+            text_encoder=self.text_encoder,
             vae=self.vae,
-            unet=unet,
+            unet=self.unet,
             tokenizer=self.tokenizer,
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
-
-        del unet
-        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -558,7 +462,11 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    unet_lora = LoraAttention()
+    unet_lora = LoraAttnProcessor(
+        cross_attention_dim=unet.cross_attention_dim,
+        inner_dim=unet.in_channels,
+        r=4,
+    )
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -618,8 +526,8 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
@@ -639,7 +547,7 @@ def main():
         [
             {
                 'params': unet_lora.parameters(),
-                'lr': args.learning_rate_unet,
+                'lr': args.learning_rate,
             },
         ],
         betas=(args.adam_beta1, args.adam_beta2),
@@ -801,7 +709,7 @@ def main():
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
         config["placeholder_token"] = " ".join(config["placeholder_token"])
-        accelerator.init_trackers("dreambooth", config=config)
+        accelerator.init_trackers("lora", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -832,6 +740,7 @@
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
+        unet_lora=unet_lora,
         output_dir=basepath,
         instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
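
The hunks above construct unet_lora, route its parameters to the single optimizer param group, and hand it to the Checkpointer, but the step that actually attaches the processor to the UNet is not visible in this diff. A minimal sketch of how that wiring could look, assuming the CrossAttention.set_processor API from diffusers.models.cross_attention (the module imported by training/lora.py below); the helper name and the "attn2" filter for cross-attention blocks are illustrative assumptions, not code from this commit:

from diffusers.models.cross_attention import CrossAttention

def apply_lora_processor(unet, unet_lora):
    # Hypothetical helper (not part of this commit): share the single
    # LoraAttnProcessor across the UNet's cross-attention layers. In
    # UNet2DConditionModel those modules are conventionally named "...attn2".
    for name, module in unet.named_modules():
        if isinstance(module, CrossAttention) and name.endswith("attn2"):
            # Assumes the diffusers version in use exposes CrossAttention.set_processor.
            module.set_processor(unet_lora)
    return unet

With a shared processor like this, only the four small lora_* projections exposed by unet_lora.parameters() are trainable, which matches the single optimizer param group shown in the hunk above.
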
diff --git a/training/lora.py b/training/lora.py
index d8dc147..e1c0971 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,27 +1,69 @@
 import torch.nn as nn
-from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config
 
+from diffusers import ModelMixin, ConfigMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.cross_attention import CrossAttention
+from diffusers.utils.import_utils import is_xformers_available
 
-class LoraAttention(ModelMixin, ConfigMixin):
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+class LoraAttnProcessor(ModelMixin, ConfigMixin):
     @register_to_config
-    def __init__(self, in_features, out_features, r=4):
+    def __init__(
+        self,
+        cross_attention_dim,
+        inner_dim,
+        r: int = 4
+    ):
         super().__init__()
 
-        if r > min(in_features, out_features):
+        if r > min(cross_attention_dim, inner_dim):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
             )
 
-        self.lora_down = nn.Linear(in_features, r, bias=False)
-        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+
+        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+
         self.scale = 1.0
 
-        self.processor = XFormersCrossAttnProcessor()
+        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_k_up.weight)
+
+        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_v_up.weight)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
+        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_up.weight)
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
-        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
-        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
         return hidden_states
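
For intuition on the new LoraAttnProcessor: only the key and value projections receive a trainable low-rank update of the form W x + up(down(x)) * scale, and because the lora_*_up weights are zero-initialized the processor initially reproduces the frozen attention exactly. A small self-contained sketch of that update in plain PyTorch; the tensor sizes and variable names are illustrative, not taken from the commit:

import torch
import torch.nn as nn

cross_attention_dim, inner_dim, r = 768, 320, 4   # illustrative sizes only

to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)  # stands in for the frozen attn.to_k
lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
lora_k_up = nn.Linear(r, inner_dim, bias=False)
nn.init.normal_(lora_k_down.weight, std=1 / r**2)
nn.init.zeros_(lora_k_up.weight)
scale = 1.0

x = torch.randn(2, 77, cross_attention_dim)                   # (batch, tokens, cross_attention_dim)
key = to_k(x) + lora_k_up(lora_k_down(x)) * scale             # same form as in LoraAttnProcessor.__call__

# With lora_k_up zero-initialized, the LoRA delta is exactly zero at the start of training.
assert torch.allclose(key, to_k(x))
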
