| field | value | date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-05-05 10:51:14 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2023-05-05 10:51:14 +0200 |
| commit | 8d2aa65402c829583e26cdf2c336b8d3057657d6 (patch) | |
| tree | cc2d47f56d1433e7600abd494361b1ae0a068f80 /train_ti.py | |
| parent | torch.compile won't work yet, keep code prepared (diff) | |
Update
Diffstat (limited to 'train_ti.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_ti.py | 45 |

1 file changed, 32 insertions(+), 13 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index fce4a5e..ae73639 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -237,7 +237,13 @@ def parse_args():
         "--offset_noise_strength",
         type=float,
         default=0,
-        help="Perlin offset noise strength.",
+        help="Offset noise strength.",
+    )
+    parser.add_argument(
+        "--input_pertubation",
+        type=float,
+        default=0,
+        help="The scale of input perturbation. Recommended 0.1."
     )
     parser.add_argument(
         "--num_train_epochs",
@@ -407,6 +413,16 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--compile_unet",
+        action="store_true",
+        help="Compile UNet with Torch Dynamo.",
+    )
+    parser.add_argument(
+        "--use_xformers",
+        action="store_true",
+        help="Use xformers.",
+    )
+    parser.add_argument(
         "--checkpoint_frequency",
         type=int,
         default=999999,
@@ -671,23 +687,24 @@ def main():
         tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-    # unet = torch.compile(unet)
+
+    if args.use_xformers:
+        vae.set_use_memory_efficient_attention_xformers(True)
+        unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    convnext = create_model(
-        "convnext_tiny",
-        pretrained=False,
-        num_classes=3,
-        drop_path_rate=0.0,
-    )
-    convnext.to(accelerator.device, dtype=weight_dtype)
-    convnext.requires_grad_(False)
-    convnext.eval()
+    # convnext = create_model(
+    #     "convnext_tiny",
+    #     pretrained=False,
+    #     num_classes=3,
+    #     drop_path_rate=0.0,
+    # )
+    # convnext.to(accelerator.device, dtype=weight_dtype)
+    # convnext.requires_grad_(False)
+    # convnext.eval()
 
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
@@ -822,6 +839,7 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        compile_unet=args.compile_unet,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
@@ -831,6 +849,7 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
+        input_pertubation=args.input_pertubation,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
```
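For context, the new `--input_pertubation` flag corresponds to the input-perturbation technique for diffusion training: extra noise is mixed into the already-noised latents that the UNet sees, while the loss target stays the clean noise sample. The diff only wires the value through to the trainer, so the sketch below shows how a training step might consume it; this is a minimal, hypothetical sketch, and the function name, variable names (`latents`, `noise`, `timesteps`), and the epsilon-prediction loss are assumptions, not this repository's actual trainer code.

```python
# Hypothetical sketch of an input-perturbation training step (illustrative
# names only; not taken from this repository's strategy/trainer modules).
import torch
import torch.nn.functional as F

def training_step(unet, noise_scheduler, latents, encoder_hidden_states,
                  input_pertubation: float = 0.0):
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps,
        (latents.shape[0],), device=latents.device
    ).long()

    if input_pertubation > 0:
        # Perturb only the noise that is added to the UNet input; keep the
        # clean `noise` as the prediction target (reduces exposure bias).
        perturbed_noise = noise + input_pertubation * torch.randn_like(noise)
        noisy_latents = noise_scheduler.add_noise(latents, perturbed_noise, timesteps)
    else:
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    return F.mse_loss(model_pred.float(), noise.float())
```

Note that the other new flags are plain `store_true` options that default to off, so existing runs keep their previous behaviour unless `--compile_unet` or `--use_xformers` is passed explicitly.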
