diff options
| -rw-r--r-- | data/csv.py | 2 | ||||
| -rw-r--r-- | train_dreambooth.py | 135 | ||||
| -rw-r--r-- | train_ti.py | 219 | ||||
| -rw-r--r-- | training/functional.py | 17 |
4 files changed, 130 insertions, 243 deletions
diff --git a/data/csv.py b/data/csv.py index 85b98f8..6857b6f 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -282,7 +282,7 @@ class VlpnDataModule(): | |||
| 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
| 283 | 283 | ||
| 284 | if valid_set_size == 0: | 284 | if valid_set_size == 0: |
| 285 | data_train, data_val = items, items[:1] | 285 | data_train, data_val = items, [] |
| 286 | else: | 286 | else: |
| 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) |
| 288 | 288 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1dc41b1..6511f9b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -200,23 +200,6 @@ def parse_args(): | |||
| 200 | default=100 | 200 | default=100 |
| 201 | ) | 201 | ) |
| 202 | parser.add_argument( | 202 | parser.add_argument( |
| 203 | "--ti_data_template", | ||
| 204 | type=str, | ||
| 205 | nargs='*', | ||
| 206 | default=[], | ||
| 207 | ) | ||
| 208 | parser.add_argument( | ||
| 209 | "--ti_num_train_epochs", | ||
| 210 | type=int, | ||
| 211 | default=10 | ||
| 212 | ) | ||
| 213 | parser.add_argument( | ||
| 214 | "--ti_batch_size", | ||
| 215 | type=int, | ||
| 216 | default=1, | ||
| 217 | help="Batch size (per device) for the training dataloader." | ||
| 218 | ) | ||
| 219 | parser.add_argument( | ||
| 220 | "--max_train_steps", | 203 | "--max_train_steps", |
| 221 | type=int, | 204 | type=int, |
| 222 | default=None, | 205 | default=None, |
| @@ -245,12 +228,6 @@ def parse_args(): | |||
| 245 | help="Initial learning rate (after the potential warmup period) to use.", | 228 | help="Initial learning rate (after the potential warmup period) to use.", |
| 246 | ) | 229 | ) |
| 247 | parser.add_argument( | 230 | parser.add_argument( |
| 248 | "--ti_learning_rate", | ||
| 249 | type=float, | ||
| 250 | default=1e-2, | ||
| 251 | help="Initial learning rate (after the potential warmup period) to use.", | ||
| 252 | ) | ||
| 253 | parser.add_argument( | ||
| 254 | "--scale_lr", | 231 | "--scale_lr", |
| 255 | action="store_true", | 232 | action="store_true", |
| 256 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | 233 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", |
| @@ -482,12 +459,6 @@ def parse_args(): | |||
| 482 | if len(args.placeholder_tokens) != len(args.num_vectors): | 459 | if len(args.placeholder_tokens) != len(args.num_vectors): |
| 483 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 460 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
| 484 | 461 | ||
| 485 | if isinstance(args.ti_data_template, str): | ||
| 486 | args.ti_data_template = [args.ti_data_template] | ||
| 487 | |||
| 488 | if len(args.ti_data_template) == 0: | ||
| 489 | raise ValueError("You must specify --ti_data_template") | ||
| 490 | |||
| 491 | if isinstance(args.collection, str): | 462 | if isinstance(args.collection, str): |
| 492 | args.collection = [args.collection] | 463 | args.collection = [args.collection] |
| 493 | 464 | ||
| @@ -521,8 +492,6 @@ def main(): | |||
| 521 | 492 | ||
| 522 | set_seed(args.seed) | 493 | set_seed(args.seed) |
| 523 | 494 | ||
| 524 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
| 525 | |||
| 526 | save_args(output_dir, args) | 495 | save_args(output_dir, args) |
| 527 | 496 | ||
| 528 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 497 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| @@ -583,107 +552,6 @@ def main(): | |||
| 583 | prior_loss_weight=args.prior_loss_weight, | 552 | prior_loss_weight=args.prior_loss_weight, |
| 584 | ) | 553 | ) |
| 585 | 554 | ||
| 586 | # Initial TI | ||
| 587 | |||
| 588 | print("Phase 1: Textual Inversion") | ||
| 589 | |||
| 590 | cur_dir = output_dir.joinpath("1-ti") | ||
| 591 | cur_dir.mkdir(parents=True, exist_ok=True) | ||
| 592 | |||
| 593 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | ||
| 594 | range(len(args.placeholder_tokens)), | ||
| 595 | args.placeholder_tokens, | ||
| 596 | args.initializer_tokens, | ||
| 597 | args.num_vectors, | ||
| 598 | args.ti_data_template | ||
| 599 | ): | ||
| 600 | cur_subdir = cur_dir.joinpath(placeholder_token) | ||
| 601 | cur_subdir.mkdir(parents=True, exist_ok=True) | ||
| 602 | |||
| 603 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
| 604 | tokenizer=tokenizer, | ||
| 605 | embeddings=embeddings, | ||
| 606 | placeholder_tokens=[placeholder_token], | ||
| 607 | initializer_tokens=[initializer_token], | ||
| 608 | num_vectors=[num_vectors] | ||
| 609 | ) | ||
| 610 | |||
| 611 | print( | ||
| 612 | f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") | ||
| 613 | |||
| 614 | args.seed = seed_generator.seed() | ||
| 615 | |||
| 616 | datamodule = VlpnDataModule( | ||
| 617 | data_file=args.train_data_file, | ||
| 618 | batch_size=args.ti_batch_size, | ||
| 619 | tokenizer=tokenizer, | ||
| 620 | class_subdir=args.class_image_dir, | ||
| 621 | num_class_images=args.num_class_images, | ||
| 622 | size=args.resolution, | ||
| 623 | shuffle=not args.no_tag_shuffle, | ||
| 624 | template_key=data_template, | ||
| 625 | valid_set_size=1, | ||
| 626 | train_set_pad=args.train_set_pad, | ||
| 627 | valid_set_pad=args.valid_set_pad, | ||
| 628 | seed=args.seed, | ||
| 629 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), | ||
| 630 | dtype=weight_dtype | ||
| 631 | ) | ||
| 632 | datamodule.setup() | ||
| 633 | |||
| 634 | optimizer = optimizer_class( | ||
| 635 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
| 636 | lr=args.ti_learning_rate, | ||
| 637 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 638 | weight_decay=0.0, | ||
| 639 | eps=args.adam_epsilon, | ||
| 640 | ) | ||
| 641 | |||
| 642 | lr_scheduler = get_scheduler( | ||
| 643 | "one_cycle", | ||
| 644 | optimizer=optimizer, | ||
| 645 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
| 646 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 647 | train_epochs=args.ti_num_train_epochs, | ||
| 648 | ) | ||
| 649 | |||
| 650 | trainer( | ||
| 651 | callbacks_fn=textual_inversion_strategy, | ||
| 652 | project="textual_inversion", | ||
| 653 | train_dataloader=datamodule.train_dataloader, | ||
| 654 | val_dataloader=datamodule.val_dataloader, | ||
| 655 | seed=args.seed, | ||
| 656 | optimizer=optimizer, | ||
| 657 | lr_scheduler=lr_scheduler, | ||
| 658 | num_train_epochs=args.ti_num_train_epochs, | ||
| 659 | sample_frequency=args.ti_num_train_epochs // 5, | ||
| 660 | checkpoint_frequency=9999999, | ||
| 661 | # -- | ||
| 662 | tokenizer=tokenizer, | ||
| 663 | sample_scheduler=sample_scheduler, | ||
| 664 | output_dir=cur_subdir, | ||
| 665 | placeholder_tokens=[placeholder_token], | ||
| 666 | placeholder_token_ids=placeholder_token_ids, | ||
| 667 | learning_rate=args.ti_learning_rate, | ||
| 668 | gradient_checkpointing=args.gradient_checkpointing, | ||
| 669 | use_emb_decay=True, | ||
| 670 | sample_batch_size=args.sample_batch_size, | ||
| 671 | sample_num_batches=args.sample_batches, | ||
| 672 | sample_num_steps=args.sample_steps, | ||
| 673 | sample_image_size=args.sample_image_size, | ||
| 674 | ) | ||
| 675 | |||
| 676 | embeddings.persist() | ||
| 677 | |||
| 678 | # Dreambooth | ||
| 679 | |||
| 680 | print("Phase 2: Dreambooth") | ||
| 681 | |||
| 682 | cur_dir = output_dir.joinpath("2-db") | ||
| 683 | cur_dir.mkdir(parents=True, exist_ok=True) | ||
| 684 | |||
| 685 | args.seed = seed_generator.seed() | ||
| 686 | |||
| 687 | datamodule = VlpnDataModule( | 555 | datamodule = VlpnDataModule( |
| 688 | data_file=args.train_data_file, | 556 | data_file=args.train_data_file, |
| 689 | batch_size=args.train_batch_size, | 557 | batch_size=args.train_batch_size, |
| @@ -746,12 +614,13 @@ def main(): | |||
| 746 | seed=args.seed, | 614 | seed=args.seed, |
| 747 | optimizer=optimizer, | 615 | optimizer=optimizer, |
| 748 | lr_scheduler=lr_scheduler, | 616 | lr_scheduler=lr_scheduler, |
| 617 | prepare_unet=True, | ||
| 749 | num_train_epochs=args.num_train_epochs, | 618 | num_train_epochs=args.num_train_epochs, |
| 750 | sample_frequency=args.sample_frequency, | 619 | sample_frequency=args.sample_frequency, |
| 751 | # -- | 620 | # -- |
| 752 | tokenizer=tokenizer, | 621 | tokenizer=tokenizer, |
| 753 | sample_scheduler=sample_scheduler, | 622 | sample_scheduler=sample_scheduler, |
| 754 | output_dir=cur_dir, | 623 | output_dir=output_dir, |
| 755 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 624 | train_text_encoder_epochs=args.train_text_encoder_epochs, |
| 756 | max_grad_norm=args.max_grad_norm, | 625 | max_grad_norm=args.max_grad_norm, |
| 757 | use_ema=args.use_ema, | 626 | use_ema=args.use_ema, |
diff --git a/train_ti.py b/train_ti.py index 7aecdef..adba8d4 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -51,6 +51,7 @@ def parse_args(): | |||
| 51 | parser.add_argument( | 51 | parser.add_argument( |
| 52 | "--train_data_template", | 52 | "--train_data_template", |
| 53 | type=str, | 53 | type=str, |
| 54 | nargs='*', | ||
| 54 | default="template", | 55 | default="template", |
| 55 | ) | 56 | ) |
| 56 | parser.add_argument( | 57 | parser.add_argument( |
| @@ -468,11 +469,17 @@ def parse_args(): | |||
| 468 | args.num_vectors = 1 | 469 | args.num_vectors = 1 |
| 469 | 470 | ||
| 470 | if isinstance(args.num_vectors, int): | 471 | if isinstance(args.num_vectors, int): |
| 471 | args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) | 472 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 472 | 473 | ||
| 473 | if len(args.placeholder_tokens) != len(args.num_vectors): | 474 | if len(args.placeholder_tokens) != len(args.num_vectors): |
| 474 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 475 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
| 475 | 476 | ||
| 477 | if isinstance(args.train_data_template, str): | ||
| 478 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | ||
| 479 | |||
| 480 | if len(args.placeholder_tokens) != len(args.train_data_template): | ||
| 481 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | ||
| 482 | |||
| 476 | if isinstance(args.collection, str): | 483 | if isinstance(args.collection, str): |
| 477 | args.collection = [args.collection] | 484 | args.collection = [args.collection] |
| 478 | 485 | ||
| @@ -507,6 +514,8 @@ def main(): | |||
| 507 | 514 | ||
| 508 | set_seed(args.seed) | 515 | set_seed(args.seed) |
| 509 | 516 | ||
| 517 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
| 518 | |||
| 510 | save_args(output_dir, args) | 519 | save_args(output_dir, args) |
| 511 | 520 | ||
| 512 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 521 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| @@ -531,19 +540,6 @@ def main(): | |||
| 531 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 540 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 532 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 541 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 533 | 542 | ||
| 534 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
| 535 | tokenizer=tokenizer, | ||
| 536 | embeddings=embeddings, | ||
| 537 | placeholder_tokens=args.placeholder_tokens, | ||
| 538 | initializer_tokens=args.initializer_tokens, | ||
| 539 | num_vectors=args.num_vectors | ||
| 540 | ) | ||
| 541 | |||
| 542 | if len(placeholder_token_ids) != 0: | ||
| 543 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] | ||
| 544 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
| 545 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
| 546 | |||
| 547 | if args.scale_lr: | 543 | if args.scale_lr: |
| 548 | args.learning_rate = ( | 544 | args.learning_rate = ( |
| 549 | args.learning_rate * args.gradient_accumulation_steps * | 545 | args.learning_rate * args.gradient_accumulation_steps * |
| @@ -566,43 +562,6 @@ def main(): | |||
| 566 | elif args.mixed_precision == "bf16": | 562 | elif args.mixed_precision == "bf16": |
| 567 | weight_dtype = torch.bfloat16 | 563 | weight_dtype = torch.bfloat16 |
| 568 | 564 | ||
| 569 | datamodule = VlpnDataModule( | ||
| 570 | data_file=args.train_data_file, | ||
| 571 | batch_size=args.train_batch_size, | ||
| 572 | tokenizer=tokenizer, | ||
| 573 | class_subdir=args.class_image_dir, | ||
| 574 | num_class_images=args.num_class_images, | ||
| 575 | size=args.resolution, | ||
| 576 | num_buckets=args.num_buckets, | ||
| 577 | progressive_buckets=args.progressive_buckets, | ||
| 578 | bucket_step_size=args.bucket_step_size, | ||
| 579 | bucket_max_pixels=args.bucket_max_pixels, | ||
| 580 | dropout=args.tag_dropout, | ||
| 581 | shuffle=not args.no_tag_shuffle, | ||
| 582 | template_key=args.train_data_template, | ||
| 583 | valid_set_size=args.valid_set_size, | ||
| 584 | train_set_pad=args.train_set_pad, | ||
| 585 | valid_set_pad=args.valid_set_pad, | ||
| 586 | seed=args.seed, | ||
| 587 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | ||
| 588 | dtype=weight_dtype | ||
| 589 | ) | ||
| 590 | datamodule.setup() | ||
| 591 | |||
| 592 | if args.num_class_images != 0: | ||
| 593 | generate_class_images( | ||
| 594 | accelerator, | ||
| 595 | text_encoder, | ||
| 596 | vae, | ||
| 597 | unet, | ||
| 598 | tokenizer, | ||
| 599 | sample_scheduler, | ||
| 600 | datamodule.train_dataset, | ||
| 601 | args.sample_batch_size, | ||
| 602 | args.sample_image_size, | ||
| 603 | args.sample_steps | ||
| 604 | ) | ||
| 605 | |||
| 606 | trainer = partial( | 565 | trainer = partial( |
| 607 | train, | 566 | train, |
| 608 | accelerator=accelerator, | 567 | accelerator=accelerator, |
| @@ -615,63 +574,111 @@ def main(): | |||
| 615 | callbacks_fn=textual_inversion_strategy | 574 | callbacks_fn=textual_inversion_strategy |
| 616 | ) | 575 | ) |
| 617 | 576 | ||
| 618 | optimizer = optimizer_class( | 577 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
| 619 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 578 | range(len(args.placeholder_tokens)), |
| 620 | lr=args.learning_rate, | 579 | args.placeholder_tokens, |
| 621 | betas=(args.adam_beta1, args.adam_beta2), | 580 | args.initializer_tokens, |
| 622 | weight_decay=args.adam_weight_decay, | 581 | args.num_vectors, |
| 623 | eps=args.adam_epsilon, | 582 | args.train_data_template |
| 624 | amsgrad=args.adam_amsgrad, | 583 | ): |
| 625 | ) | 584 | cur_subdir = output_dir.joinpath(placeholder_token) |
| 585 | cur_subdir.mkdir(parents=True, exist_ok=True) | ||
| 626 | 586 | ||
| 627 | lr_scheduler = get_scheduler( | 587 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 628 | args.lr_scheduler, | 588 | tokenizer=tokenizer, |
| 629 | optimizer=optimizer, | 589 | embeddings=embeddings, |
| 630 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 590 | placeholder_tokens=[placeholder_token], |
| 631 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 591 | initializer_tokens=[initializer_token], |
| 632 | min_lr=args.lr_min_lr, | 592 | num_vectors=[num_vectors] |
| 633 | warmup_func=args.lr_warmup_func, | 593 | ) |
| 634 | annealing_func=args.lr_annealing_func, | ||
| 635 | warmup_exp=args.lr_warmup_exp, | ||
| 636 | annealing_exp=args.lr_annealing_exp, | ||
| 637 | cycles=args.lr_cycles, | ||
| 638 | train_epochs=args.num_train_epochs, | ||
| 639 | warmup_epochs=args.lr_warmup_epochs, | ||
| 640 | ) | ||
| 641 | 594 | ||
| 642 | trainer( | 595 | print( |
| 643 | project="textual_inversion", | 596 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") |
| 644 | train_dataloader=datamodule.train_dataloader, | 597 | |
| 645 | val_dataloader=datamodule.val_dataloader, | 598 | args.seed = seed_generator.seed() |
| 646 | optimizer=optimizer, | 599 | |
| 647 | lr_scheduler=lr_scheduler, | 600 | datamodule = VlpnDataModule( |
| 648 | num_train_epochs=args.num_train_epochs, | 601 | data_file=args.train_data_file, |
| 649 | sample_frequency=args.sample_frequency, | 602 | batch_size=args.train_batch_size, |
| 650 | checkpoint_frequency=args.checkpoint_frequency, | 603 | tokenizer=tokenizer, |
| 651 | global_step_offset=global_step_offset, | 604 | class_subdir=args.class_image_dir, |
| 652 | with_prior_preservation=args.num_class_images != 0, | 605 | num_class_images=args.num_class_images, |
| 653 | prior_loss_weight=args.prior_loss_weight, | 606 | size=args.resolution, |
| 654 | # -- | 607 | num_buckets=args.num_buckets, |
| 655 | tokenizer=tokenizer, | 608 | progressive_buckets=args.progressive_buckets, |
| 656 | sample_scheduler=sample_scheduler, | 609 | bucket_step_size=args.bucket_step_size, |
| 657 | output_dir=output_dir, | 610 | bucket_max_pixels=args.bucket_max_pixels, |
| 658 | placeholder_tokens=args.placeholder_tokens, | 611 | dropout=args.tag_dropout, |
| 659 | placeholder_token_ids=placeholder_token_ids, | 612 | shuffle=not args.no_tag_shuffle, |
| 660 | learning_rate=args.learning_rate, | 613 | template_key=data_template, |
| 661 | gradient_checkpointing=args.gradient_checkpointing, | 614 | valid_set_size=args.valid_set_size, |
| 662 | use_emb_decay=args.use_emb_decay, | 615 | train_set_pad=args.train_set_pad, |
| 663 | emb_decay_target=args.emb_decay_target, | 616 | valid_set_pad=args.valid_set_pad, |
| 664 | emb_decay_factor=args.emb_decay_factor, | 617 | seed=args.seed, |
| 665 | emb_decay_start=args.emb_decay_start, | 618 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), |
| 666 | use_ema=args.use_ema, | 619 | dtype=weight_dtype |
| 667 | ema_inv_gamma=args.ema_inv_gamma, | 620 | ) |
| 668 | ema_power=args.ema_power, | 621 | datamodule.setup() |
| 669 | ema_max_decay=args.ema_max_decay, | 622 | |
| 670 | sample_batch_size=args.sample_batch_size, | 623 | optimizer = optimizer_class( |
| 671 | sample_num_batches=args.sample_batches, | 624 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
| 672 | sample_num_steps=args.sample_steps, | 625 | lr=args.learning_rate, |
| 673 | sample_image_size=args.sample_image_size, | 626 | betas=(args.adam_beta1, args.adam_beta2), |
| 674 | ) | 627 | weight_decay=args.adam_weight_decay, |
| 628 | eps=args.adam_epsilon, | ||
| 629 | amsgrad=args.adam_amsgrad, | ||
| 630 | ) | ||
| 631 | |||
| 632 | lr_scheduler = get_scheduler( | ||
| 633 | args.lr_scheduler, | ||
| 634 | optimizer=optimizer, | ||
| 635 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
| 636 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 637 | min_lr=args.lr_min_lr, | ||
| 638 | warmup_func=args.lr_warmup_func, | ||
| 639 | annealing_func=args.lr_annealing_func, | ||
| 640 | warmup_exp=args.lr_warmup_exp, | ||
| 641 | annealing_exp=args.lr_annealing_exp, | ||
| 642 | cycles=args.lr_cycles, | ||
| 643 | train_epochs=args.num_train_epochs, | ||
| 644 | warmup_epochs=args.lr_warmup_epochs, | ||
| 645 | ) | ||
| 646 | |||
| 647 | trainer( | ||
| 648 | project="textual_inversion", | ||
| 649 | train_dataloader=datamodule.train_dataloader, | ||
| 650 | val_dataloader=datamodule.val_dataloader, | ||
| 651 | optimizer=optimizer, | ||
| 652 | lr_scheduler=lr_scheduler, | ||
| 653 | num_train_epochs=args.num_train_epochs, | ||
| 654 | sample_frequency=args.sample_frequency, | ||
| 655 | checkpoint_frequency=args.checkpoint_frequency, | ||
| 656 | global_step_offset=global_step_offset, | ||
| 657 | with_prior_preservation=args.num_class_images != 0, | ||
| 658 | prior_loss_weight=args.prior_loss_weight, | ||
| 659 | # -- | ||
| 660 | tokenizer=tokenizer, | ||
| 661 | sample_scheduler=sample_scheduler, | ||
| 662 | output_dir=cur_subdir, | ||
| 663 | placeholder_tokens=[placeholder_token], | ||
| 664 | placeholder_token_ids=placeholder_token_ids, | ||
| 665 | learning_rate=args.learning_rate, | ||
| 666 | gradient_checkpointing=args.gradient_checkpointing, | ||
| 667 | use_emb_decay=args.use_emb_decay, | ||
| 668 | emb_decay_target=args.emb_decay_target, | ||
| 669 | emb_decay_factor=args.emb_decay_factor, | ||
| 670 | emb_decay_start=args.emb_decay_start, | ||
| 671 | use_ema=args.use_ema, | ||
| 672 | ema_inv_gamma=args.ema_inv_gamma, | ||
| 673 | ema_power=args.ema_power, | ||
| 674 | ema_max_decay=args.ema_max_decay, | ||
| 675 | sample_batch_size=args.sample_batch_size, | ||
| 676 | sample_num_batches=args.sample_batches, | ||
| 677 | sample_num_steps=args.sample_steps, | ||
| 678 | sample_image_size=args.sample_image_size, | ||
| 679 | ) | ||
| 680 | |||
| 681 | embeddings.persist() | ||
| 675 | 682 | ||
| 676 | 683 | ||
| 677 | if __name__ == "__main__": | 684 | if __name__ == "__main__": |
diff --git a/training/functional.py b/training/functional.py index b6b5d87..1548784 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -515,6 +515,7 @@ def train( | |||
| 515 | optimizer: torch.optim.Optimizer, | 515 | optimizer: torch.optim.Optimizer, |
| 516 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 516 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 517 | callbacks_fn: Callable[..., TrainingCallbacks], | 517 | callbacks_fn: Callable[..., TrainingCallbacks], |
| 518 | prepare_unet: bool = False, | ||
| 518 | num_train_epochs: int = 100, | 519 | num_train_epochs: int = 100, |
| 519 | sample_frequency: int = 20, | 520 | sample_frequency: int = 20, |
| 520 | checkpoint_frequency: int = 50, | 521 | checkpoint_frequency: int = 50, |
| @@ -523,9 +524,19 @@ def train( | |||
| 523 | prior_loss_weight: float = 1.0, | 524 | prior_loss_weight: float = 1.0, |
| 524 | **kwargs, | 525 | **kwargs, |
| 525 | ): | 526 | ): |
| 526 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 527 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] |
| 527 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 528 | |
| 528 | ) | 529 | if prepare_unet: |
| 530 | prep.append(unet) | ||
| 531 | |||
| 532 | prep = accelerator.prepare(*prep) | ||
| 533 | |||
| 534 | if prepare_unet: | ||
| 535 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | ||
| 536 | else: | ||
| 537 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | ||
| 538 | |||
| 539 | unet.to(accelerator.device, dtype=dtype) | ||
| 529 | 540 | ||
| 530 | vae.to(accelerator.device, dtype=dtype) | 541 | vae.to(accelerator.device, dtype=dtype) |
| 531 | 542 | ||
