diff options
-rw-r--r-- | data/csv.py | 2 | ||||
-rw-r--r-- | train_dreambooth.py | 135 | ||||
-rw-r--r-- | train_ti.py | 221 | ||||
-rw-r--r-- | training/functional.py | 17 |
4 files changed, 131 insertions, 244 deletions
diff --git a/data/csv.py b/data/csv.py index 85b98f8..6857b6f 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -282,7 +282,7 @@ class VlpnDataModule(): | |||
282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
283 | 283 | ||
284 | if valid_set_size == 0: | 284 | if valid_set_size == 0: |
285 | data_train, data_val = items, items[:1] | 285 | data_train, data_val = items, [] |
286 | else: | 286 | else: |
287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) |
288 | 288 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1dc41b1..6511f9b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -200,23 +200,6 @@ def parse_args(): | |||
200 | default=100 | 200 | default=100 |
201 | ) | 201 | ) |
202 | parser.add_argument( | 202 | parser.add_argument( |
203 | "--ti_data_template", | ||
204 | type=str, | ||
205 | nargs='*', | ||
206 | default=[], | ||
207 | ) | ||
208 | parser.add_argument( | ||
209 | "--ti_num_train_epochs", | ||
210 | type=int, | ||
211 | default=10 | ||
212 | ) | ||
213 | parser.add_argument( | ||
214 | "--ti_batch_size", | ||
215 | type=int, | ||
216 | default=1, | ||
217 | help="Batch size (per device) for the training dataloader." | ||
218 | ) | ||
219 | parser.add_argument( | ||
220 | "--max_train_steps", | 203 | "--max_train_steps", |
221 | type=int, | 204 | type=int, |
222 | default=None, | 205 | default=None, |
@@ -245,12 +228,6 @@ def parse_args(): | |||
245 | help="Initial learning rate (after the potential warmup period) to use.", | 228 | help="Initial learning rate (after the potential warmup period) to use.", |
246 | ) | 229 | ) |
247 | parser.add_argument( | 230 | parser.add_argument( |
248 | "--ti_learning_rate", | ||
249 | type=float, | ||
250 | default=1e-2, | ||
251 | help="Initial learning rate (after the potential warmup period) to use.", | ||
252 | ) | ||
253 | parser.add_argument( | ||
254 | "--scale_lr", | 231 | "--scale_lr", |
255 | action="store_true", | 232 | action="store_true", |
256 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | 233 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", |
@@ -482,12 +459,6 @@ def parse_args(): | |||
482 | if len(args.placeholder_tokens) != len(args.num_vectors): | 459 | if len(args.placeholder_tokens) != len(args.num_vectors): |
483 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 460 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
484 | 461 | ||
485 | if isinstance(args.ti_data_template, str): | ||
486 | args.ti_data_template = [args.ti_data_template] | ||
487 | |||
488 | if len(args.ti_data_template) == 0: | ||
489 | raise ValueError("You must specify --ti_data_template") | ||
490 | |||
491 | if isinstance(args.collection, str): | 462 | if isinstance(args.collection, str): |
492 | args.collection = [args.collection] | 463 | args.collection = [args.collection] |
493 | 464 | ||
@@ -521,8 +492,6 @@ def main(): | |||
521 | 492 | ||
522 | set_seed(args.seed) | 493 | set_seed(args.seed) |
523 | 494 | ||
524 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
525 | |||
526 | save_args(output_dir, args) | 495 | save_args(output_dir, args) |
527 | 496 | ||
528 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 497 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
@@ -583,107 +552,6 @@ def main(): | |||
583 | prior_loss_weight=args.prior_loss_weight, | 552 | prior_loss_weight=args.prior_loss_weight, |
584 | ) | 553 | ) |
585 | 554 | ||
586 | # Initial TI | ||
587 | |||
588 | print("Phase 1: Textual Inversion") | ||
589 | |||
590 | cur_dir = output_dir.joinpath("1-ti") | ||
591 | cur_dir.mkdir(parents=True, exist_ok=True) | ||
592 | |||
593 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | ||
594 | range(len(args.placeholder_tokens)), | ||
595 | args.placeholder_tokens, | ||
596 | args.initializer_tokens, | ||
597 | args.num_vectors, | ||
598 | args.ti_data_template | ||
599 | ): | ||
600 | cur_subdir = cur_dir.joinpath(placeholder_token) | ||
601 | cur_subdir.mkdir(parents=True, exist_ok=True) | ||
602 | |||
603 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
604 | tokenizer=tokenizer, | ||
605 | embeddings=embeddings, | ||
606 | placeholder_tokens=[placeholder_token], | ||
607 | initializer_tokens=[initializer_token], | ||
608 | num_vectors=[num_vectors] | ||
609 | ) | ||
610 | |||
611 | print( | ||
612 | f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") | ||
613 | |||
614 | args.seed = seed_generator.seed() | ||
615 | |||
616 | datamodule = VlpnDataModule( | ||
617 | data_file=args.train_data_file, | ||
618 | batch_size=args.ti_batch_size, | ||
619 | tokenizer=tokenizer, | ||
620 | class_subdir=args.class_image_dir, | ||
621 | num_class_images=args.num_class_images, | ||
622 | size=args.resolution, | ||
623 | shuffle=not args.no_tag_shuffle, | ||
624 | template_key=data_template, | ||
625 | valid_set_size=1, | ||
626 | train_set_pad=args.train_set_pad, | ||
627 | valid_set_pad=args.valid_set_pad, | ||
628 | seed=args.seed, | ||
629 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), | ||
630 | dtype=weight_dtype | ||
631 | ) | ||
632 | datamodule.setup() | ||
633 | |||
634 | optimizer = optimizer_class( | ||
635 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
636 | lr=args.ti_learning_rate, | ||
637 | betas=(args.adam_beta1, args.adam_beta2), | ||
638 | weight_decay=0.0, | ||
639 | eps=args.adam_epsilon, | ||
640 | ) | ||
641 | |||
642 | lr_scheduler = get_scheduler( | ||
643 | "one_cycle", | ||
644 | optimizer=optimizer, | ||
645 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
646 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
647 | train_epochs=args.ti_num_train_epochs, | ||
648 | ) | ||
649 | |||
650 | trainer( | ||
651 | callbacks_fn=textual_inversion_strategy, | ||
652 | project="textual_inversion", | ||
653 | train_dataloader=datamodule.train_dataloader, | ||
654 | val_dataloader=datamodule.val_dataloader, | ||
655 | seed=args.seed, | ||
656 | optimizer=optimizer, | ||
657 | lr_scheduler=lr_scheduler, | ||
658 | num_train_epochs=args.ti_num_train_epochs, | ||
659 | sample_frequency=args.ti_num_train_epochs // 5, | ||
660 | checkpoint_frequency=9999999, | ||
661 | # -- | ||
662 | tokenizer=tokenizer, | ||
663 | sample_scheduler=sample_scheduler, | ||
664 | output_dir=cur_subdir, | ||
665 | placeholder_tokens=[placeholder_token], | ||
666 | placeholder_token_ids=placeholder_token_ids, | ||
667 | learning_rate=args.ti_learning_rate, | ||
668 | gradient_checkpointing=args.gradient_checkpointing, | ||
669 | use_emb_decay=True, | ||
670 | sample_batch_size=args.sample_batch_size, | ||
671 | sample_num_batches=args.sample_batches, | ||
672 | sample_num_steps=args.sample_steps, | ||
673 | sample_image_size=args.sample_image_size, | ||
674 | ) | ||
675 | |||
676 | embeddings.persist() | ||
677 | |||
678 | # Dreambooth | ||
679 | |||
680 | print("Phase 2: Dreambooth") | ||
681 | |||
682 | cur_dir = output_dir.joinpath("2-db") | ||
683 | cur_dir.mkdir(parents=True, exist_ok=True) | ||
684 | |||
685 | args.seed = seed_generator.seed() | ||
686 | |||
687 | datamodule = VlpnDataModule( | 555 | datamodule = VlpnDataModule( |
688 | data_file=args.train_data_file, | 556 | data_file=args.train_data_file, |
689 | batch_size=args.train_batch_size, | 557 | batch_size=args.train_batch_size, |
@@ -746,12 +614,13 @@ def main(): | |||
746 | seed=args.seed, | 614 | seed=args.seed, |
747 | optimizer=optimizer, | 615 | optimizer=optimizer, |
748 | lr_scheduler=lr_scheduler, | 616 | lr_scheduler=lr_scheduler, |
617 | prepare_unet=True, | ||
749 | num_train_epochs=args.num_train_epochs, | 618 | num_train_epochs=args.num_train_epochs, |
750 | sample_frequency=args.sample_frequency, | 619 | sample_frequency=args.sample_frequency, |
751 | # -- | 620 | # -- |
752 | tokenizer=tokenizer, | 621 | tokenizer=tokenizer, |
753 | sample_scheduler=sample_scheduler, | 622 | sample_scheduler=sample_scheduler, |
754 | output_dir=cur_dir, | 623 | output_dir=output_dir, |
755 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 624 | train_text_encoder_epochs=args.train_text_encoder_epochs, |
756 | max_grad_norm=args.max_grad_norm, | 625 | max_grad_norm=args.max_grad_norm, |
757 | use_ema=args.use_ema, | 626 | use_ema=args.use_ema, |
diff --git a/train_ti.py b/train_ti.py index 7aecdef..adba8d4 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -51,6 +51,7 @@ def parse_args(): | |||
51 | parser.add_argument( | 51 | parser.add_argument( |
52 | "--train_data_template", | 52 | "--train_data_template", |
53 | type=str, | 53 | type=str, |
54 | nargs='*', | ||
54 | default="template", | 55 | default="template", |
55 | ) | 56 | ) |
56 | parser.add_argument( | 57 | parser.add_argument( |
@@ -468,11 +469,17 @@ def parse_args(): | |||
468 | args.num_vectors = 1 | 469 | args.num_vectors = 1 |
469 | 470 | ||
470 | if isinstance(args.num_vectors, int): | 471 | if isinstance(args.num_vectors, int): |
471 | args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) | 472 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
472 | 473 | ||
473 | if len(args.placeholder_tokens) != len(args.num_vectors): | 474 | if len(args.placeholder_tokens) != len(args.num_vectors): |
474 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 475 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
475 | 476 | ||
477 | if isinstance(args.train_data_template, str): | ||
478 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | ||
479 | |||
480 | if len(args.placeholder_tokens) != len(args.train_data_template): | ||
481 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | ||
482 | |||
476 | if isinstance(args.collection, str): | 483 | if isinstance(args.collection, str): |
477 | args.collection = [args.collection] | 484 | args.collection = [args.collection] |
478 | 485 | ||
@@ -507,6 +514,8 @@ def main(): | |||
507 | 514 | ||
508 | set_seed(args.seed) | 515 | set_seed(args.seed) |
509 | 516 | ||
517 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
518 | |||
510 | save_args(output_dir, args) | 519 | save_args(output_dir, args) |
511 | 520 | ||
512 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 521 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
@@ -531,19 +540,6 @@ def main(): | |||
531 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 540 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
532 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 541 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
533 | 542 | ||
534 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
535 | tokenizer=tokenizer, | ||
536 | embeddings=embeddings, | ||
537 | placeholder_tokens=args.placeholder_tokens, | ||
538 | initializer_tokens=args.initializer_tokens, | ||
539 | num_vectors=args.num_vectors | ||
540 | ) | ||
541 | |||
542 | if len(placeholder_token_ids) != 0: | ||
543 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] | ||
544 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
545 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
546 | |||
547 | if args.scale_lr: | 543 | if args.scale_lr: |
548 | args.learning_rate = ( | 544 | args.learning_rate = ( |
549 | args.learning_rate * args.gradient_accumulation_steps * | 545 | args.learning_rate * args.gradient_accumulation_steps * |
@@ -566,43 +562,6 @@ def main(): | |||
566 | elif args.mixed_precision == "bf16": | 562 | elif args.mixed_precision == "bf16": |
567 | weight_dtype = torch.bfloat16 | 563 | weight_dtype = torch.bfloat16 |
568 | 564 | ||
569 | datamodule = VlpnDataModule( | ||
570 | data_file=args.train_data_file, | ||
571 | batch_size=args.train_batch_size, | ||
572 | tokenizer=tokenizer, | ||
573 | class_subdir=args.class_image_dir, | ||
574 | num_class_images=args.num_class_images, | ||
575 | size=args.resolution, | ||
576 | num_buckets=args.num_buckets, | ||
577 | progressive_buckets=args.progressive_buckets, | ||
578 | bucket_step_size=args.bucket_step_size, | ||
579 | bucket_max_pixels=args.bucket_max_pixels, | ||
580 | dropout=args.tag_dropout, | ||
581 | shuffle=not args.no_tag_shuffle, | ||
582 | template_key=args.train_data_template, | ||
583 | valid_set_size=args.valid_set_size, | ||
584 | train_set_pad=args.train_set_pad, | ||
585 | valid_set_pad=args.valid_set_pad, | ||
586 | seed=args.seed, | ||
587 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | ||
588 | dtype=weight_dtype | ||
589 | ) | ||
590 | datamodule.setup() | ||
591 | |||
592 | if args.num_class_images != 0: | ||
593 | generate_class_images( | ||
594 | accelerator, | ||
595 | text_encoder, | ||
596 | vae, | ||
597 | unet, | ||
598 | tokenizer, | ||
599 | sample_scheduler, | ||
600 | datamodule.train_dataset, | ||
601 | args.sample_batch_size, | ||
602 | args.sample_image_size, | ||
603 | args.sample_steps | ||
604 | ) | ||
605 | |||
606 | trainer = partial( | 565 | trainer = partial( |
607 | train, | 566 | train, |
608 | accelerator=accelerator, | 567 | accelerator=accelerator, |
@@ -615,63 +574,111 @@ def main(): | |||
615 | callbacks_fn=textual_inversion_strategy | 574 | callbacks_fn=textual_inversion_strategy |
616 | ) | 575 | ) |
617 | 576 | ||
618 | optimizer = optimizer_class( | 577 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
619 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 578 | range(len(args.placeholder_tokens)), |
620 | lr=args.learning_rate, | 579 | args.placeholder_tokens, |
621 | betas=(args.adam_beta1, args.adam_beta2), | 580 | args.initializer_tokens, |
622 | weight_decay=args.adam_weight_decay, | 581 | args.num_vectors, |
623 | eps=args.adam_epsilon, | 582 | args.train_data_template |
624 | amsgrad=args.adam_amsgrad, | 583 | ): |
625 | ) | 584 | cur_subdir = output_dir.joinpath(placeholder_token) |
585 | cur_subdir.mkdir(parents=True, exist_ok=True) | ||
586 | |||
587 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
588 | tokenizer=tokenizer, | ||
589 | embeddings=embeddings, | ||
590 | placeholder_tokens=[placeholder_token], | ||
591 | initializer_tokens=[initializer_token], | ||
592 | num_vectors=[num_vectors] | ||
593 | ) | ||
626 | 594 | ||
627 | lr_scheduler = get_scheduler( | 595 | print( |
628 | args.lr_scheduler, | 596 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") |
629 | optimizer=optimizer, | 597 | |
630 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 598 | args.seed = seed_generator.seed() |
631 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 599 | |
632 | min_lr=args.lr_min_lr, | 600 | datamodule = VlpnDataModule( |
633 | warmup_func=args.lr_warmup_func, | 601 | data_file=args.train_data_file, |
634 | annealing_func=args.lr_annealing_func, | 602 | batch_size=args.train_batch_size, |
635 | warmup_exp=args.lr_warmup_exp, | 603 | tokenizer=tokenizer, |
636 | annealing_exp=args.lr_annealing_exp, | 604 | class_subdir=args.class_image_dir, |
637 | cycles=args.lr_cycles, | 605 | num_class_images=args.num_class_images, |
638 | train_epochs=args.num_train_epochs, | 606 | size=args.resolution, |
639 | warmup_epochs=args.lr_warmup_epochs, | 607 | num_buckets=args.num_buckets, |
640 | ) | 608 | progressive_buckets=args.progressive_buckets, |
641 | 609 | bucket_step_size=args.bucket_step_size, | |
642 | trainer( | 610 | bucket_max_pixels=args.bucket_max_pixels, |
643 | project="textual_inversion", | 611 | dropout=args.tag_dropout, |
644 | train_dataloader=datamodule.train_dataloader, | 612 | shuffle=not args.no_tag_shuffle, |
645 | val_dataloader=datamodule.val_dataloader, | 613 | template_key=data_template, |
646 | optimizer=optimizer, | 614 | valid_set_size=args.valid_set_size, |
647 | lr_scheduler=lr_scheduler, | 615 | train_set_pad=args.train_set_pad, |
648 | num_train_epochs=args.num_train_epochs, | 616 | valid_set_pad=args.valid_set_pad, |
649 | sample_frequency=args.sample_frequency, | 617 | seed=args.seed, |
650 | checkpoint_frequency=args.checkpoint_frequency, | 618 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), |
651 | global_step_offset=global_step_offset, | 619 | dtype=weight_dtype |
652 | with_prior_preservation=args.num_class_images != 0, | 620 | ) |
653 | prior_loss_weight=args.prior_loss_weight, | 621 | datamodule.setup() |
654 | # -- | 622 | |
655 | tokenizer=tokenizer, | 623 | optimizer = optimizer_class( |
656 | sample_scheduler=sample_scheduler, | 624 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
657 | output_dir=output_dir, | 625 | lr=args.learning_rate, |
658 | placeholder_tokens=args.placeholder_tokens, | 626 | betas=(args.adam_beta1, args.adam_beta2), |
659 | placeholder_token_ids=placeholder_token_ids, | 627 | weight_decay=args.adam_weight_decay, |
660 | learning_rate=args.learning_rate, | 628 | eps=args.adam_epsilon, |
661 | gradient_checkpointing=args.gradient_checkpointing, | 629 | amsgrad=args.adam_amsgrad, |
662 | use_emb_decay=args.use_emb_decay, | 630 | ) |
663 | emb_decay_target=args.emb_decay_target, | 631 | |
664 | emb_decay_factor=args.emb_decay_factor, | 632 | lr_scheduler = get_scheduler( |
665 | emb_decay_start=args.emb_decay_start, | 633 | args.lr_scheduler, |
666 | use_ema=args.use_ema, | 634 | optimizer=optimizer, |
667 | ema_inv_gamma=args.ema_inv_gamma, | 635 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
668 | ema_power=args.ema_power, | 636 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
669 | ema_max_decay=args.ema_max_decay, | 637 | min_lr=args.lr_min_lr, |
670 | sample_batch_size=args.sample_batch_size, | 638 | warmup_func=args.lr_warmup_func, |
671 | sample_num_batches=args.sample_batches, | 639 | annealing_func=args.lr_annealing_func, |
672 | sample_num_steps=args.sample_steps, | 640 | warmup_exp=args.lr_warmup_exp, |
673 | sample_image_size=args.sample_image_size, | 641 | annealing_exp=args.lr_annealing_exp, |
674 | ) | 642 | cycles=args.lr_cycles, |
643 | train_epochs=args.num_train_epochs, | ||
644 | warmup_epochs=args.lr_warmup_epochs, | ||
645 | ) | ||
646 | |||
647 | trainer( | ||
648 | project="textual_inversion", | ||
649 | train_dataloader=datamodule.train_dataloader, | ||
650 | val_dataloader=datamodule.val_dataloader, | ||
651 | optimizer=optimizer, | ||
652 | lr_scheduler=lr_scheduler, | ||
653 | num_train_epochs=args.num_train_epochs, | ||
654 | sample_frequency=args.sample_frequency, | ||
655 | checkpoint_frequency=args.checkpoint_frequency, | ||
656 | global_step_offset=global_step_offset, | ||
657 | with_prior_preservation=args.num_class_images != 0, | ||
658 | prior_loss_weight=args.prior_loss_weight, | ||
659 | # -- | ||
660 | tokenizer=tokenizer, | ||
661 | sample_scheduler=sample_scheduler, | ||
662 | output_dir=cur_subdir, | ||
663 | placeholder_tokens=[placeholder_token], | ||
664 | placeholder_token_ids=placeholder_token_ids, | ||
665 | learning_rate=args.learning_rate, | ||
666 | gradient_checkpointing=args.gradient_checkpointing, | ||
667 | use_emb_decay=args.use_emb_decay, | ||
668 | emb_decay_target=args.emb_decay_target, | ||
669 | emb_decay_factor=args.emb_decay_factor, | ||
670 | emb_decay_start=args.emb_decay_start, | ||
671 | use_ema=args.use_ema, | ||
672 | ema_inv_gamma=args.ema_inv_gamma, | ||
673 | ema_power=args.ema_power, | ||
674 | ema_max_decay=args.ema_max_decay, | ||
675 | sample_batch_size=args.sample_batch_size, | ||
676 | sample_num_batches=args.sample_batches, | ||
677 | sample_num_steps=args.sample_steps, | ||
678 | sample_image_size=args.sample_image_size, | ||
679 | ) | ||
680 | |||
681 | embeddings.persist() | ||
675 | 682 | ||
676 | 683 | ||
677 | if __name__ == "__main__": | 684 | if __name__ == "__main__": |
diff --git a/training/functional.py b/training/functional.py index b6b5d87..1548784 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -515,6 +515,7 @@ def train( | |||
515 | optimizer: torch.optim.Optimizer, | 515 | optimizer: torch.optim.Optimizer, |
516 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 516 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
517 | callbacks_fn: Callable[..., TrainingCallbacks], | 517 | callbacks_fn: Callable[..., TrainingCallbacks], |
518 | prepare_unet: bool = False, | ||
518 | num_train_epochs: int = 100, | 519 | num_train_epochs: int = 100, |
519 | sample_frequency: int = 20, | 520 | sample_frequency: int = 20, |
520 | checkpoint_frequency: int = 50, | 521 | checkpoint_frequency: int = 50, |
@@ -523,9 +524,19 @@ def train( | |||
523 | prior_loss_weight: float = 1.0, | 524 | prior_loss_weight: float = 1.0, |
524 | **kwargs, | 525 | **kwargs, |
525 | ): | 526 | ): |
526 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 527 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] |
527 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 528 | |
528 | ) | 529 | if prepare_unet: |
530 | prep.append(unet) | ||
531 | |||
532 | prep = accelerator.prepare(*prep) | ||
533 | |||
534 | if prepare_unet: | ||
535 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | ||
536 | else: | ||
537 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | ||
538 | |||
539 | unet.to(accelerator.device, dtype=dtype) | ||
529 | 540 | ||
530 | vae.to(accelerator.device, dtype=dtype) | 541 | vae.to(accelerator.device, dtype=dtype) |
531 | 542 | ||