diff options
| -rw-r--r-- | data/csv.py | 11 | ||||
| -rw-r--r-- | train_ti.py | 74 | ||||
| -rw-r--r-- | training/functional.py | 100 | ||||
| -rw-r--r-- | training/lr.py | 29 | ||||
| -rw-r--r-- | training/strategy/ti.py | 54 | 
5 files changed, 106 insertions, 162 deletions
diff --git a/data/csv.py b/data/csv.py index b058a3e..5de3ac7 100644 --- a/data/csv.py +++ b/data/csv.py  | |||
| @@ -100,28 +100,25 @@ def generate_buckets( | |||
| 100 | return buckets, bucket_items, bucket_assignments | 100 | return buckets, bucket_items, bucket_assignments | 
| 101 | 101 | ||
| 102 | 102 | ||
| 103 | def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples): | 103 | def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples): | 
| 104 | with_prior = all("class_prompt_ids" in example for example in examples) | ||
| 105 | |||
| 106 | prompt_ids = [example["prompt_ids"] for example in examples] | 104 | prompt_ids = [example["prompt_ids"] for example in examples] | 
| 107 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 105 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 
| 108 | 106 | ||
| 109 | input_ids = [example["instance_prompt_ids"] for example in examples] | 107 | input_ids = [example["instance_prompt_ids"] for example in examples] | 
| 110 | pixel_values = [example["instance_images"] for example in examples] | 108 | pixel_values = [example["instance_images"] for example in examples] | 
| 111 | 109 | ||
| 112 | if with_prior: | 110 | if with_prior_preservation: | 
| 113 | input_ids += [example["class_prompt_ids"] for example in examples] | 111 | input_ids += [example["class_prompt_ids"] for example in examples] | 
| 114 | pixel_values += [example["class_images"] for example in examples] | 112 | pixel_values += [example["class_images"] for example in examples] | 
| 115 | 113 | ||
| 116 | pixel_values = torch.stack(pixel_values) | 114 | pixel_values = torch.stack(pixel_values) | 
| 117 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 115 | pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format) | 
| 118 | 116 | ||
| 119 | prompts = unify_input_ids(tokenizer, prompt_ids) | 117 | prompts = unify_input_ids(tokenizer, prompt_ids) | 
| 120 | nprompts = unify_input_ids(tokenizer, nprompt_ids) | 118 | nprompts = unify_input_ids(tokenizer, nprompt_ids) | 
| 121 | inputs = unify_input_ids(tokenizer, input_ids) | 119 | inputs = unify_input_ids(tokenizer, input_ids) | 
| 122 | 120 | ||
| 123 | batch = { | 121 | batch = { | 
| 124 | "with_prior": torch.tensor([with_prior] * len(examples)), | ||
| 125 | "prompt_ids": prompts.input_ids, | 122 | "prompt_ids": prompts.input_ids, | 
| 126 | "nprompt_ids": nprompts.input_ids, | 123 | "nprompt_ids": nprompts.input_ids, | 
| 127 | "input_ids": inputs.input_ids, | 124 | "input_ids": inputs.input_ids, | 
| @@ -285,7 +282,7 @@ class VlpnDataModule(): | |||
| 285 | size=self.size, interpolation=self.interpolation, | 282 | size=self.size, interpolation=self.interpolation, | 
| 286 | ) | 283 | ) | 
| 287 | 284 | ||
| 288 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer) | 285 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 
| 289 | 286 | ||
| 290 | self.train_dataloader = DataLoader( | 287 | self.train_dataloader = DataLoader( | 
| 291 | train_dataset, | 288 | train_dataset, | 
diff --git a/train_ti.py b/train_ti.py index 3c9810f..4bac736 100644 --- a/train_ti.py +++ b/train_ti.py  | |||
| @@ -15,11 +15,11 @@ from slugify import slugify | |||
| 15 | 15 | ||
| 16 | from util import load_config, load_embeddings_from_dir | 16 | from util import load_config, load_embeddings_from_dir | 
| 17 | from data.csv import VlpnDataModule, VlpnDataItem | 17 | from data.csv import VlpnDataModule, VlpnDataItem | 
| 18 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models | 18 | from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models | 
| 19 | from training.strategy.ti import textual_inversion_strategy | 19 | from training.strategy.ti import textual_inversion_strategy | 
| 20 | from training.optimization import get_scheduler | 20 | from training.optimization import get_scheduler | 
| 21 | from training.lr import LRFinder | 21 | from training.lr import LRFinder | 
| 22 | from training.util import EMAModel, save_args | 22 | from training.util import save_args | 
| 23 | 23 | ||
| 24 | logger = get_logger(__name__) | 24 | logger = get_logger(__name__) | 
| 25 | 25 | ||
| @@ -82,7 +82,7 @@ def parse_args(): | |||
| 82 | parser.add_argument( | 82 | parser.add_argument( | 
| 83 | "--num_class_images", | 83 | "--num_class_images", | 
| 84 | type=int, | 84 | type=int, | 
| 85 | default=1, | 85 | default=0, | 
| 86 | help="How many class images to generate." | 86 | help="How many class images to generate." | 
| 87 | ) | 87 | ) | 
| 88 | parser.add_argument( | 88 | parser.add_argument( | 
| @@ -398,7 +398,7 @@ def parse_args(): | |||
| 398 | ) | 398 | ) | 
| 399 | parser.add_argument( | 399 | parser.add_argument( | 
| 400 | "--emb_decay_factor", | 400 | "--emb_decay_factor", | 
| 401 | default=0, | 401 | default=1, | 
| 402 | type=float, | 402 | type=float, | 
| 403 | help="Embedding decay factor." | 403 | help="Embedding decay factor." | 
| 404 | ) | 404 | ) | 
| @@ -540,16 +540,6 @@ def main(): | |||
| 540 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | 540 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | 
| 541 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | 541 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | 
| 542 | 542 | ||
| 543 | if args.use_ema: | ||
| 544 | ema_embeddings = EMAModel( | ||
| 545 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
| 546 | inv_gamma=args.ema_inv_gamma, | ||
| 547 | power=args.ema_power, | ||
| 548 | max_value=args.ema_max_decay, | ||
| 549 | ) | ||
| 550 | else: | ||
| 551 | ema_embeddings = None | ||
| 552 | |||
| 553 | if args.scale_lr: | 543 | if args.scale_lr: | 
| 554 | args.learning_rate = ( | 544 | args.learning_rate = ( | 
| 555 | args.learning_rate * args.gradient_accumulation_steps * | 545 | args.learning_rate * args.gradient_accumulation_steps * | 
| @@ -654,23 +644,13 @@ def main(): | |||
| 654 | warmup_epochs=args.lr_warmup_epochs, | 644 | warmup_epochs=args.lr_warmup_epochs, | 
| 655 | ) | 645 | ) | 
| 656 | 646 | ||
| 657 | if args.use_ema: | 647 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 
| 658 | ema_embeddings.to(accelerator.device) | 648 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 
| 659 | |||
| 660 | trainer = partial( | ||
| 661 | train, | ||
| 662 | accelerator=accelerator, | ||
| 663 | vae=vae, | ||
| 664 | unet=unet, | ||
| 665 | text_encoder=text_encoder, | ||
| 666 | noise_scheduler=noise_scheduler, | ||
| 667 | train_dataloader=train_dataloader, | ||
| 668 | val_dataloader=val_dataloader, | ||
| 669 | dtype=weight_dtype, | ||
| 670 | seed=args.seed, | ||
| 671 | ) | 649 | ) | 
| 672 | 650 | ||
| 673 | strategy = textual_inversion_strategy( | 651 | vae.to(accelerator.device, dtype=weight_dtype) | 
| 652 | |||
| 653 | callbacks = textual_inversion_strategy( | ||
| 674 | accelerator=accelerator, | 654 | accelerator=accelerator, | 
| 675 | unet=unet, | 655 | unet=unet, | 
| 676 | text_encoder=text_encoder, | 656 | text_encoder=text_encoder, | 
| @@ -679,7 +659,6 @@ def main(): | |||
| 679 | sample_scheduler=sample_scheduler, | 659 | sample_scheduler=sample_scheduler, | 
| 680 | train_dataloader=train_dataloader, | 660 | train_dataloader=train_dataloader, | 
| 681 | val_dataloader=val_dataloader, | 661 | val_dataloader=val_dataloader, | 
| 682 | dtype=weight_dtype, | ||
| 683 | output_dir=output_dir, | 662 | output_dir=output_dir, | 
| 684 | seed=args.seed, | 663 | seed=args.seed, | 
| 685 | placeholder_tokens=args.placeholder_tokens, | 664 | placeholder_tokens=args.placeholder_tokens, | 
| @@ -700,31 +679,54 @@ def main(): | |||
| 700 | sample_image_size=args.sample_image_size, | 679 | sample_image_size=args.sample_image_size, | 
| 701 | ) | 680 | ) | 
| 702 | 681 | ||
| 682 | for model in (unet, text_encoder, vae): | ||
| 683 | model.requires_grad_(False) | ||
| 684 | model.eval() | ||
| 685 | |||
| 686 | callbacks.on_prepare() | ||
| 687 | |||
| 688 | loss_step_ = partial( | ||
| 689 | loss_step, | ||
| 690 | vae, | ||
| 691 | noise_scheduler, | ||
| 692 | unet, | ||
| 693 | text_encoder, | ||
| 694 | args.num_class_images != 0, | ||
| 695 | args.prior_loss_weight, | ||
| 696 | args.seed, | ||
| 697 | ) | ||
| 698 | |||
| 703 | if args.find_lr: | 699 | if args.find_lr: | 
| 704 | lr_finder = LRFinder( | 700 | lr_finder = LRFinder( | 
| 705 | accelerator=accelerator, | 701 | accelerator=accelerator, | 
| 706 | optimizer=optimizer, | 702 | optimizer=optimizer, | 
| 707 | model=text_encoder, | ||
| 708 | train_dataloader=train_dataloader, | 703 | train_dataloader=train_dataloader, | 
| 709 | val_dataloader=val_dataloader, | 704 | val_dataloader=val_dataloader, | 
| 710 | **strategy, | 705 | callbacks=callbacks, | 
| 711 | ) | 706 | ) | 
| 712 | lr_finder.run(num_epochs=100, end_lr=1e3) | 707 | lr_finder.run(num_epochs=100, end_lr=1e3) | 
| 713 | 708 | ||
| 714 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) | 709 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) | 
| 715 | plt.close() | 710 | plt.close() | 
| 716 | else: | 711 | else: | 
| 717 | trainer( | 712 | if accelerator.is_main_process: | 
| 713 | accelerator.init_trackers("textual_inversion") | ||
| 714 | |||
| 715 | train_loop( | ||
| 716 | accelerator=accelerator, | ||
| 718 | optimizer=optimizer, | 717 | optimizer=optimizer, | 
| 719 | lr_scheduler=lr_scheduler, | 718 | lr_scheduler=lr_scheduler, | 
| 720 | num_train_epochs=args.num_train_epochs, | 719 | train_dataloader=train_dataloader, | 
| 720 | val_dataloader=val_dataloader, | ||
| 721 | loss_step=loss_step_, | ||
| 721 | sample_frequency=args.sample_frequency, | 722 | sample_frequency=args.sample_frequency, | 
| 722 | checkpoint_frequency=args.checkpoint_frequency, | 723 | checkpoint_frequency=args.checkpoint_frequency, | 
| 723 | global_step_offset=global_step_offset, | 724 | global_step_offset=global_step_offset, | 
| 724 | prior_loss_weight=args.prior_loss_weight, | 725 | callbacks=callbacks, | 
| 725 | callbacks=strategy, | ||
| 726 | ) | 726 | ) | 
| 727 | 727 | ||
| 728 | accelerator.end_training() | ||
| 729 | |||
| 728 | 730 | ||
| 729 | if __name__ == "__main__": | 731 | if __name__ == "__main__": | 
| 730 | main() | 732 | main() | 
diff --git a/training/functional.py b/training/functional.py index 4ca7470..c01595a 100644 --- a/training/functional.py +++ b/training/functional.py  | |||
| @@ -33,6 +33,7 @@ def const(result=None): | |||
| 33 | @dataclass | 33 | @dataclass | 
| 34 | class TrainingCallbacks(): | 34 | class TrainingCallbacks(): | 
| 35 | on_prepare: Callable[[float], None] = const() | 35 | on_prepare: Callable[[float], None] = const() | 
| 36 | on_model: Callable[[], torch.nn.Module] = const(None) | ||
| 36 | on_log: Callable[[], dict[str, Any]] = const({}) | 37 | on_log: Callable[[], dict[str, Any]] = const({}) | 
| 37 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 
| 38 | on_before_optimize: Callable[[int], None] = const() | 39 | on_before_optimize: Callable[[int], None] = const() | 
| @@ -267,6 +268,7 @@ def loss_step( | |||
| 267 | noise_scheduler: DDPMScheduler, | 268 | noise_scheduler: DDPMScheduler, | 
| 268 | unet: UNet2DConditionModel, | 269 | unet: UNet2DConditionModel, | 
| 269 | text_encoder: CLIPTextModel, | 270 | text_encoder: CLIPTextModel, | 
| 271 | with_prior_preservation: bool, | ||
| 270 | prior_loss_weight: float, | 272 | prior_loss_weight: float, | 
| 271 | seed: int, | 273 | seed: int, | 
| 272 | step: int, | 274 | step: int, | 
| @@ -322,7 +324,7 @@ def loss_step( | |||
| 322 | else: | 324 | else: | 
| 323 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 325 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 
| 324 | 326 | ||
| 325 | if batch["with_prior"].all(): | 327 | if with_prior_preservation: | 
| 326 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 328 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 
| 327 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 329 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 
| 328 | target, target_prior = torch.chunk(target, 2, dim=0) | 330 | target, target_prior = torch.chunk(target, 2, dim=0) | 
| @@ -347,7 +349,6 @@ def train_loop( | |||
| 347 | accelerator: Accelerator, | 349 | accelerator: Accelerator, | 
| 348 | optimizer: torch.optim.Optimizer, | 350 | optimizer: torch.optim.Optimizer, | 
| 349 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 351 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 
| 350 | model: torch.nn.Module, | ||
| 351 | train_dataloader: DataLoader, | 352 | train_dataloader: DataLoader, | 
| 352 | val_dataloader: DataLoader, | 353 | val_dataloader: DataLoader, | 
| 353 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 354 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 
| @@ -387,28 +388,37 @@ def train_loop( | |||
| 387 | ) | 388 | ) | 
| 388 | global_progress_bar.set_description("Total progress") | 389 | global_progress_bar.set_description("Total progress") | 
| 389 | 390 | ||
| 391 | model = callbacks.on_model() | ||
| 392 | on_log = callbacks.on_log | ||
| 393 | on_train = callbacks.on_train | ||
| 394 | on_before_optimize = callbacks.on_before_optimize | ||
| 395 | on_after_optimize = callbacks.on_after_optimize | ||
| 396 | on_eval = callbacks.on_eval | ||
| 397 | on_sample = callbacks.on_sample | ||
| 398 | on_checkpoint = callbacks.on_checkpoint | ||
| 399 | |||
| 390 | try: | 400 | try: | 
| 391 | for epoch in range(num_epochs): | 401 | for epoch in range(num_epochs): | 
| 392 | if accelerator.is_main_process: | 402 | if accelerator.is_main_process: | 
| 393 | if epoch % sample_frequency == 0: | 403 | if epoch % sample_frequency == 0: | 
| 394 | callbacks.on_sample(global_step + global_step_offset) | 404 | on_sample(global_step + global_step_offset) | 
| 395 | 405 | ||
| 396 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 406 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 
| 397 | callbacks.on_checkpoint(global_step + global_step_offset, "training") | 407 | on_checkpoint(global_step + global_step_offset, "training") | 
| 398 | 408 | ||
| 399 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 409 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 
| 400 | local_progress_bar.reset() | 410 | local_progress_bar.reset() | 
| 401 | 411 | ||
| 402 | model.train() | 412 | model.train() | 
| 403 | 413 | ||
| 404 | with callbacks.on_train(epoch): | 414 | with on_train(epoch): | 
| 405 | for step, batch in enumerate(train_dataloader): | 415 | for step, batch in enumerate(train_dataloader): | 
| 406 | with accelerator.accumulate(model): | 416 | with accelerator.accumulate(model): | 
| 407 | loss, acc, bsz = loss_step(step, batch) | 417 | loss, acc, bsz = loss_step(step, batch) | 
| 408 | 418 | ||
| 409 | accelerator.backward(loss) | 419 | accelerator.backward(loss) | 
| 410 | 420 | ||
| 411 | callbacks.on_before_optimize(epoch) | 421 | on_before_optimize(epoch) | 
| 412 | 422 | ||
| 413 | optimizer.step() | 423 | optimizer.step() | 
| 414 | lr_scheduler.step() | 424 | lr_scheduler.step() | 
| @@ -419,7 +429,7 @@ def train_loop( | |||
| 419 | 429 | ||
| 420 | # Checks if the accelerator has performed an optimization step behind the scenes | 430 | # Checks if the accelerator has performed an optimization step behind the scenes | 
| 421 | if accelerator.sync_gradients: | 431 | if accelerator.sync_gradients: | 
| 422 | callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0]) | 432 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | 
| 423 | 433 | ||
| 424 | local_progress_bar.update(1) | 434 | local_progress_bar.update(1) | 
| 425 | global_progress_bar.update(1) | 435 | global_progress_bar.update(1) | 
| @@ -433,7 +443,7 @@ def train_loop( | |||
| 433 | "train/cur_acc": acc.item(), | 443 | "train/cur_acc": acc.item(), | 
| 434 | "lr": lr_scheduler.get_last_lr()[0], | 444 | "lr": lr_scheduler.get_last_lr()[0], | 
| 435 | } | 445 | } | 
| 436 | logs.update(callbacks.on_log()) | 446 | logs.update(on_log()) | 
| 437 | 447 | ||
| 438 | accelerator.log(logs, step=global_step) | 448 | accelerator.log(logs, step=global_step) | 
| 439 | 449 | ||
| @@ -449,7 +459,7 @@ def train_loop( | |||
| 449 | cur_loss_val = AverageMeter() | 459 | cur_loss_val = AverageMeter() | 
| 450 | cur_acc_val = AverageMeter() | 460 | cur_acc_val = AverageMeter() | 
| 451 | 461 | ||
| 452 | with torch.inference_mode(), callbacks.on_eval(): | 462 | with torch.inference_mode(), on_eval(): | 
| 453 | for step, batch in enumerate(val_dataloader): | 463 | for step, batch in enumerate(val_dataloader): | 
| 454 | loss, acc, bsz = loss_step(step, batch, True) | 464 | loss, acc, bsz = loss_step(step, batch, True) | 
| 455 | 465 | ||
| @@ -485,80 +495,16 @@ def train_loop( | |||
| 485 | if avg_acc_val.avg.item() > max_acc_val: | 495 | if avg_acc_val.avg.item() > max_acc_val: | 
| 486 | accelerator.print( | 496 | accelerator.print( | 
| 487 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 497 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 
| 488 | callbacks.on_checkpoint(global_step + global_step_offset, "milestone") | 498 | on_checkpoint(global_step + global_step_offset, "milestone") | 
| 489 | max_acc_val = avg_acc_val.avg.item() | 499 | max_acc_val = avg_acc_val.avg.item() | 
| 490 | 500 | ||
| 491 | # Create the pipeline using using the trained modules and save it. | 501 | # Create the pipeline using using the trained modules and save it. | 
| 492 | if accelerator.is_main_process: | 502 | if accelerator.is_main_process: | 
| 493 | print("Finished!") | 503 | print("Finished!") | 
| 494 | callbacks.on_checkpoint(global_step + global_step_offset, "end") | 504 | on_checkpoint(global_step + global_step_offset, "end") | 
| 495 | callbacks.on_sample(global_step + global_step_offset) | 505 | on_sample(global_step + global_step_offset) | 
| 496 | accelerator.end_training() | ||
| 497 | 506 | ||
| 498 | except KeyboardInterrupt: | 507 | except KeyboardInterrupt: | 
| 499 | if accelerator.is_main_process: | 508 | if accelerator.is_main_process: | 
| 500 | print("Interrupted") | 509 | print("Interrupted") | 
| 501 | callbacks.on_checkpoint(global_step + global_step_offset, "end") | 510 | on_checkpoint(global_step + global_step_offset, "end") | 
| 502 | accelerator.end_training() | ||
| 503 | |||
| 504 | |||
| 505 | def train( | ||
| 506 | accelerator: Accelerator, | ||
| 507 | unet: UNet2DConditionModel, | ||
| 508 | text_encoder: CLIPTextModel, | ||
| 509 | vae: AutoencoderKL, | ||
| 510 | noise_scheduler: DDPMScheduler, | ||
| 511 | train_dataloader: DataLoader, | ||
| 512 | val_dataloader: DataLoader, | ||
| 513 | dtype: torch.dtype, | ||
| 514 | seed: int, | ||
| 515 | optimizer: torch.optim.Optimizer, | ||
| 516 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
| 517 | num_train_epochs: int = 100, | ||
| 518 | sample_frequency: int = 20, | ||
| 519 | checkpoint_frequency: int = 50, | ||
| 520 | global_step_offset: int = 0, | ||
| 521 | prior_loss_weight: float = 0, | ||
| 522 | callbacks: TrainingCallbacks = TrainingCallbacks(), | ||
| 523 | ): | ||
| 524 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | ||
| 525 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 526 | ) | ||
| 527 | |||
| 528 | vae.to(accelerator.device, dtype=dtype) | ||
| 529 | |||
| 530 | for model in (unet, text_encoder, vae): | ||
| 531 | model.requires_grad_(False) | ||
| 532 | model.eval() | ||
| 533 | |||
| 534 | callbacks.on_prepare() | ||
| 535 | |||
| 536 | loss_step_ = partial( | ||
| 537 | loss_step, | ||
| 538 | vae, | ||
| 539 | noise_scheduler, | ||
| 540 | unet, | ||
| 541 | text_encoder, | ||
| 542 | prior_loss_weight, | ||
| 543 | seed, | ||
| 544 | ) | ||
| 545 | |||
| 546 | if accelerator.is_main_process: | ||
| 547 | accelerator.init_trackers("textual_inversion") | ||
| 548 | |||
| 549 | train_loop( | ||
| 550 | accelerator=accelerator, | ||
| 551 | optimizer=optimizer, | ||
| 552 | lr_scheduler=lr_scheduler, | ||
| 553 | model=text_encoder, | ||
| 554 | train_dataloader=train_dataloader, | ||
| 555 | val_dataloader=val_dataloader, | ||
| 556 | loss_step=loss_step_, | ||
| 557 | sample_frequency=sample_frequency, | ||
| 558 | checkpoint_frequency=checkpoint_frequency, | ||
| 559 | global_step_offset=global_step_offset, | ||
| 560 | num_epochs=num_train_epochs, | ||
| 561 | callbacks=callbacks, | ||
| 562 | ) | ||
| 563 | |||
| 564 | accelerator.free_memory() | ||
diff --git a/training/lr.py b/training/lr.py index 7584ba2..902c4eb 100644 --- a/training/lr.py +++ b/training/lr.py  | |||
| @@ -9,6 +9,7 @@ import torch | |||
| 9 | from torch.optim.lr_scheduler import LambdaLR | 9 | from torch.optim.lr_scheduler import LambdaLR | 
| 10 | from tqdm.auto import tqdm | 10 | from tqdm.auto import tqdm | 
| 11 | 11 | ||
| 12 | from training.functional import TrainingCallbacks | ||
| 12 | from training.util import AverageMeter | 13 | from training.util import AverageMeter | 
| 13 | 14 | ||
| 14 | 15 | ||
| @@ -24,26 +25,19 @@ class LRFinder(): | |||
| 24 | def __init__( | 25 | def __init__( | 
| 25 | self, | 26 | self, | 
| 26 | accelerator, | 27 | accelerator, | 
| 27 | model, | ||
| 28 | optimizer, | 28 | optimizer, | 
| 29 | train_dataloader, | 29 | train_dataloader, | 
| 30 | val_dataloader, | 30 | val_dataloader, | 
| 31 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 31 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 
| 32 | on_train: Callable[[int], _GeneratorContextManager] = noop_ctx, | 32 | callbacks: TrainingCallbacks = TrainingCallbacks() | 
| 33 | on_before_optimize: Callable[[int], None] = noop, | ||
| 34 | on_after_optimize: Callable[[float], None] = noop, | ||
| 35 | on_eval: Callable[[], _GeneratorContextManager] = noop_ctx | ||
| 36 | ): | 33 | ): | 
| 37 | self.accelerator = accelerator | 34 | self.accelerator = accelerator | 
| 38 | self.model = model | 35 | self.model = callbacks.on_model() | 
| 39 | self.optimizer = optimizer | 36 | self.optimizer = optimizer | 
| 40 | self.train_dataloader = train_dataloader | 37 | self.train_dataloader = train_dataloader | 
| 41 | self.val_dataloader = val_dataloader | 38 | self.val_dataloader = val_dataloader | 
| 42 | self.loss_fn = loss_fn | 39 | self.loss_fn = loss_fn | 
| 43 | self.on_train = on_train | 40 | self.callbacks = callbacks | 
| 44 | self.on_before_optimize = on_before_optimize | ||
| 45 | self.on_after_optimize = on_after_optimize | ||
| 46 | self.on_eval = on_eval | ||
| 47 | 41 | ||
| 48 | # self.model_state = copy.deepcopy(model.state_dict()) | 42 | # self.model_state = copy.deepcopy(model.state_dict()) | 
| 49 | # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | 43 | # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) | 
| @@ -82,6 +76,13 @@ class LRFinder(): | |||
| 82 | ) | 76 | ) | 
| 83 | progress_bar.set_description("Epoch X / Y") | 77 | progress_bar.set_description("Epoch X / Y") | 
| 84 | 78 | ||
| 79 | self.callbacks.on_prepare() | ||
| 80 | |||
| 81 | on_train = self.callbacks.on_train | ||
| 82 | on_before_optimize = self.callbacks.on_before_optimize | ||
| 83 | on_after_optimize = self.callbacks.on_after_optimize | ||
| 84 | on_eval = self.callbacks.on_eval | ||
| 85 | |||
| 85 | for epoch in range(num_epochs): | 86 | for epoch in range(num_epochs): | 
| 86 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 87 | progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 
| 87 | 88 | ||
| @@ -90,7 +91,7 @@ class LRFinder(): | |||
| 90 | 91 | ||
| 91 | self.model.train() | 92 | self.model.train() | 
| 92 | 93 | ||
| 93 | with self.on_train(epoch): | 94 | with on_train(epoch): | 
| 94 | for step, batch in enumerate(self.train_dataloader): | 95 | for step, batch in enumerate(self.train_dataloader): | 
| 95 | if step >= num_train_batches: | 96 | if step >= num_train_batches: | 
| 96 | break | 97 | break | 
| @@ -100,21 +101,21 @@ class LRFinder(): | |||
| 100 | 101 | ||
| 101 | self.accelerator.backward(loss) | 102 | self.accelerator.backward(loss) | 
| 102 | 103 | ||
| 103 | self.on_before_optimize(epoch) | 104 | on_before_optimize(epoch) | 
| 104 | 105 | ||
| 105 | self.optimizer.step() | 106 | self.optimizer.step() | 
| 106 | lr_scheduler.step() | 107 | lr_scheduler.step() | 
| 107 | self.optimizer.zero_grad(set_to_none=True) | 108 | self.optimizer.zero_grad(set_to_none=True) | 
| 108 | 109 | ||
| 109 | if self.accelerator.sync_gradients: | 110 | if self.accelerator.sync_gradients: | 
| 110 | self.on_after_optimize(lr_scheduler.get_last_lr()[0]) | 111 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | 
| 111 | 112 | ||
| 112 | progress_bar.update(1) | 113 | progress_bar.update(1) | 
| 113 | 114 | ||
| 114 | self.model.eval() | 115 | self.model.eval() | 
| 115 | 116 | ||
| 116 | with torch.inference_mode(): | 117 | with torch.inference_mode(): | 
| 117 | with self.on_eval(): | 118 | with on_eval(): | 
| 118 | for step, batch in enumerate(self.val_dataloader): | 119 | for step, batch in enumerate(self.val_dataloader): | 
| 119 | if step >= num_val_batches: | 120 | if step >= num_val_batches: | 
| 120 | break | 121 | break | 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6f8384f..753dce0 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py  | |||
| @@ -27,7 +27,6 @@ def textual_inversion_strategy( | |||
| 27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, | 
| 28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, | 
| 29 | val_dataloader: DataLoader, | 29 | val_dataloader: DataLoader, | 
| 30 | dtype: torch.dtype, | ||
| 31 | output_dir: Path, | 30 | output_dir: Path, | 
| 32 | seed: int, | 31 | seed: int, | 
| 33 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], | 
| @@ -48,6 +47,12 @@ def textual_inversion_strategy( | |||
| 48 | sample_guidance_scale: float = 7.5, | 47 | sample_guidance_scale: float = 7.5, | 
| 49 | sample_image_size: Optional[int] = None, | 48 | sample_image_size: Optional[int] = None, | 
| 50 | ): | 49 | ): | 
| 50 | weight_dtype = torch.float32 | ||
| 51 | if accelerator.state.mixed_precision == "fp16": | ||
| 52 | weight_dtype = torch.float16 | ||
| 53 | elif accelerator.state.mixed_precision == "bf16": | ||
| 54 | weight_dtype = torch.bfloat16 | ||
| 55 | |||
| 51 | save_samples_ = partial( | 56 | save_samples_ = partial( | 
| 52 | save_samples, | 57 | save_samples, | 
| 53 | accelerator=accelerator, | 58 | accelerator=accelerator, | 
| @@ -58,7 +63,7 @@ def textual_inversion_strategy( | |||
| 58 | sample_scheduler=sample_scheduler, | 63 | sample_scheduler=sample_scheduler, | 
| 59 | train_dataloader=train_dataloader, | 64 | train_dataloader=train_dataloader, | 
| 60 | val_dataloader=val_dataloader, | 65 | val_dataloader=val_dataloader, | 
| 61 | dtype=dtype, | 66 | dtype=weight_dtype, | 
| 62 | output_dir=output_dir, | 67 | output_dir=output_dir, | 
| 63 | seed=seed, | 68 | seed=seed, | 
| 64 | batch_size=sample_batch_size, | 69 | batch_size=sample_batch_size, | 
| @@ -78,6 +83,17 @@ def textual_inversion_strategy( | |||
| 78 | else: | 83 | else: | 
| 79 | ema_embeddings = None | 84 | ema_embeddings = None | 
| 80 | 85 | ||
| 86 | def ema_context(): | ||
| 87 | if use_ema: | ||
| 88 | return ema_embeddings.apply_temporary( | ||
| 89 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | ||
| 90 | ) | ||
| 91 | else: | ||
| 92 | return nullcontext() | ||
| 93 | |||
| 94 | def on_model(): | ||
| 95 | return text_encoder | ||
| 96 | |||
| 81 | def on_prepare(): | 97 | def on_prepare(): | 
| 82 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | 98 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | 
| 83 | 99 | ||
| @@ -89,24 +105,15 @@ def textual_inversion_strategy( | |||
| 89 | 105 | ||
| 90 | @contextmanager | 106 | @contextmanager | 
| 91 | def on_train(epoch: int): | 107 | def on_train(epoch: int): | 
| 92 | try: | 108 | tokenizer.train() | 
| 93 | tokenizer.train() | 109 | yield | 
| 94 | yield | ||
| 95 | finally: | ||
| 96 | pass | ||
| 97 | 110 | ||
| 98 | @contextmanager | 111 | @contextmanager | 
| 99 | def on_eval(): | 112 | def on_eval(): | 
| 100 | try: | 113 | tokenizer.eval() | 
| 101 | tokenizer.eval() | ||
| 102 | 114 | ||
| 103 | ema_context = ema_embeddings.apply_temporary( | 115 | with ema_context(): | 
| 104 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext() | 116 | yield | 
| 105 | |||
| 106 | with ema_context: | ||
| 107 | yield | ||
| 108 | finally: | ||
| 109 | pass | ||
| 110 | 117 | ||
| 111 | @torch.no_grad() | 118 | @torch.no_grad() | 
| 112 | def on_after_optimize(lr: float): | 119 | def on_after_optimize(lr: float): | 
| @@ -131,13 +138,7 @@ def textual_inversion_strategy( | |||
| 131 | checkpoints_path = output_dir.joinpath("checkpoints") | 138 | checkpoints_path = output_dir.joinpath("checkpoints") | 
| 132 | checkpoints_path.mkdir(parents=True, exist_ok=True) | 139 | checkpoints_path.mkdir(parents=True, exist_ok=True) | 
| 133 | 140 | ||
| 134 | text_encoder = accelerator.unwrap_model(text_encoder) | 141 | with ema_context(): | 
| 135 | |||
| 136 | ema_context = ema_embeddings.apply_temporary( | ||
| 137 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | ||
| 138 | ) if ema_embeddings is not None else nullcontext() | ||
| 139 | |||
| 140 | with ema_context: | ||
| 141 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 142 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 
| 142 | text_encoder.text_model.embeddings.save_embed( | 143 | text_encoder.text_model.embeddings.save_embed( | 
| 143 | ids, | 144 | ids, | 
| @@ -146,15 +147,12 @@ def textual_inversion_strategy( | |||
| 146 | 147 | ||
| 147 | @torch.no_grad() | 148 | @torch.no_grad() | 
| 148 | def on_sample(step): | 149 | def on_sample(step): | 
| 149 | ema_context = ema_embeddings.apply_temporary( | 150 | with ema_context(): | 
| 150 | text_encoder.text_model.embeddings.temp_token_embedding.parameters() | ||
| 151 | ) if ema_embeddings is not None else nullcontext() | ||
| 152 | |||
| 153 | with ema_context: | ||
| 154 | save_samples_(step=step) | 151 | save_samples_(step=step) | 
| 155 | 152 | ||
| 156 | return TrainingCallbacks( | 153 | return TrainingCallbacks( | 
| 157 | on_prepare=on_prepare, | 154 | on_prepare=on_prepare, | 
| 155 | on_model=on_model, | ||
| 158 | on_train=on_train, | 156 | on_train=on_train, | 
| 159 | on_eval=on_eval, | 157 | on_eval=on_eval, | 
| 160 | on_after_optimize=on_after_optimize, | 158 | on_after_optimize=on_after_optimize, | 
