diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 45 |
1 files changed, 32 insertions, 13 deletions
diff --git a/train_ti.py b/train_ti.py
index fce4a5e..ae73639 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -237,7 +237,13 @@ def parse_args(): | |||
237 | "--offset_noise_strength", | 237 | "--offset_noise_strength", |
238 | type=float, | 238 | type=float, |
239 | default=0, | 239 | default=0, |
240 | help="Perlin offset noise strength.", | 240 | help="Offset noise strength.", |
241 | ) | ||
242 | parser.add_argument( | ||
243 | "--input_pertubation", | ||
244 | type=float, | ||
245 | default=0, | ||
246 | help="The scale of input pretubation. Recommended 0.1." | ||
241 | ) | 247 | ) |
242 | parser.add_argument( | 248 | parser.add_argument( |
243 | "--num_train_epochs", | 249 | "--num_train_epochs", |
@@ -407,6 +413,16 @@ def parse_args(): | |||
407 | ), | 413 | ), |
408 | ) | 414 | ) |
409 | parser.add_argument( | 415 | parser.add_argument( |
416 | "--compile_unet", | ||
417 | action="store_true", | ||
418 | help="Compile UNet with Torch Dynamo.", | ||
419 | ) | ||
420 | parser.add_argument( | ||
421 | "--use_xformers", | ||
422 | action="store_true", | ||
423 | help="Use xformers.", | ||
424 | ) | ||
425 | parser.add_argument( | ||
410 | "--checkpoint_frequency", | 426 | "--checkpoint_frequency", |
411 | type=int, | 427 | type=int, |
412 | default=999999, | 428 | default=999999, |
@@ -671,23 +687,24 @@ def main(): | |||
671 | tokenizer.set_dropout(args.vector_dropout) | 687 | tokenizer.set_dropout(args.vector_dropout) |
672 | 688 | ||
673 | vae.enable_slicing() | 689 | vae.enable_slicing() |
674 | vae.set_use_memory_efficient_attention_xformers(True) | 690 | |
675 | unet.enable_xformers_memory_efficient_attention() | 691 | if args.use_xformers: |
676 | # unet = torch.compile(unet) | 692 | vae.set_use_memory_efficient_attention_xformers(True) |
693 | unet.enable_xformers_memory_efficient_attention() | ||
677 | 694 | ||
678 | if args.gradient_checkpointing: | 695 | if args.gradient_checkpointing: |
679 | unet.enable_gradient_checkpointing() | 696 | unet.enable_gradient_checkpointing() |
680 | text_encoder.gradient_checkpointing_enable() | 697 | text_encoder.gradient_checkpointing_enable() |
681 | 698 | ||
682 | convnext = create_model( | 699 | # convnext = create_model( |
683 | "convnext_tiny", | 700 | # "convnext_tiny", |
684 | pretrained=False, | 701 | # pretrained=False, |
685 | num_classes=3, | 702 | # num_classes=3, |
686 | drop_path_rate=0.0, | 703 | # drop_path_rate=0.0, |
687 | ) | 704 | # ) |
688 | convnext.to(accelerator.device, dtype=weight_dtype) | 705 | # convnext.to(accelerator.device, dtype=weight_dtype) |
689 | convnext.requires_grad_(False) | 706 | # convnext.requires_grad_(False) |
690 | convnext.eval() | 707 | # convnext.eval() |
691 | 708 | ||
692 | if len(args.alias_tokens) != 0: | 709 | if len(args.alias_tokens) != 0: |
693 | alias_placeholder_tokens = args.alias_tokens[::2] | 710 | alias_placeholder_tokens = args.alias_tokens[::2] |
@@ -822,6 +839,7 @@ def main(): | |||
822 | noise_scheduler=noise_scheduler, | 839 | noise_scheduler=noise_scheduler, |
823 | dtype=weight_dtype, | 840 | dtype=weight_dtype, |
824 | seed=args.seed, | 841 | seed=args.seed, |
842 | compile_unet=args.compile_unet, | ||
825 | guidance_scale=args.guidance_scale, | 843 | guidance_scale=args.guidance_scale, |
826 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 844 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
827 | no_val=args.valid_set_size == 0, | 845 | no_val=args.valid_set_size == 0, |
@@ -831,6 +849,7 @@ def main(): | |||
831 | milestone_checkpoints=not args.no_milestone_checkpoints, | 849 | milestone_checkpoints=not args.no_milestone_checkpoints, |
832 | global_step_offset=global_step_offset, | 850 | global_step_offset=global_step_offset, |
833 | offset_noise_strength=args.offset_noise_strength, | 851 | offset_noise_strength=args.offset_noise_strength, |
852 | input_pertubation=args.input_pertubation, | ||
834 | # -- | 853 | # -- |
835 | use_emb_decay=args.use_emb_decay, | 854 | use_emb_decay=args.use_emb_decay, |
836 | emb_decay_target=args.emb_decay_target, | 855 | emb_decay_target=args.emb_decay_target, |