summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-05-05 10:51:14 +0200
committerVolpeon <git@volpeon.ink>2023-05-05 10:51:14 +0200
commit8d2aa65402c829583e26cdf2c336b8d3057657d6 (patch)
treecc2d47f56d1433e7600abd494361b1ae0a068f80 /train_ti.py
parenttorch.compile won't work yet, keep code prepared (diff)
downloadtextual-inversion-diff-8d2aa65402c829583e26cdf2c336b8d3057657d6.tar.gz
textual-inversion-diff-8d2aa65402c829583e26cdf2c336b8d3057657d6.tar.bz2
textual-inversion-diff-8d2aa65402c829583e26cdf2c336b8d3057657d6.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py45
1 file changed, 32 insertions, 13 deletions
diff --git a/train_ti.py b/train_ti.py
index fce4a5e..ae73639 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -237,7 +237,13 @@ def parse_args():
237 "--offset_noise_strength", 237 "--offset_noise_strength",
238 type=float, 238 type=float,
239 default=0, 239 default=0,
240 help="Perlin offset noise strength.", 240 help="Offset noise strength.",
241 )
242 parser.add_argument(
243 "--input_pertubation",
244 type=float,
245 default=0,
246 help="The scale of input pretubation. Recommended 0.1."
241 ) 247 )
242 parser.add_argument( 248 parser.add_argument(
243 "--num_train_epochs", 249 "--num_train_epochs",
@@ -407,6 +413,16 @@ def parse_args():
407 ), 413 ),
408 ) 414 )
409 parser.add_argument( 415 parser.add_argument(
416 "--compile_unet",
417 action="store_true",
418 help="Compile UNet with Torch Dynamo.",
419 )
420 parser.add_argument(
421 "--use_xformers",
422 action="store_true",
423 help="Use xformers.",
424 )
425 parser.add_argument(
410 "--checkpoint_frequency", 426 "--checkpoint_frequency",
411 type=int, 427 type=int,
412 default=999999, 428 default=999999,
@@ -671,23 +687,24 @@ def main():
671 tokenizer.set_dropout(args.vector_dropout) 687 tokenizer.set_dropout(args.vector_dropout)
672 688
673 vae.enable_slicing() 689 vae.enable_slicing()
674 vae.set_use_memory_efficient_attention_xformers(True) 690
675 unet.enable_xformers_memory_efficient_attention() 691 if args.use_xformers:
676 # unet = torch.compile(unet) 692 vae.set_use_memory_efficient_attention_xformers(True)
693 unet.enable_xformers_memory_efficient_attention()
677 694
678 if args.gradient_checkpointing: 695 if args.gradient_checkpointing:
679 unet.enable_gradient_checkpointing() 696 unet.enable_gradient_checkpointing()
680 text_encoder.gradient_checkpointing_enable() 697 text_encoder.gradient_checkpointing_enable()
681 698
682 convnext = create_model( 699 # convnext = create_model(
683 "convnext_tiny", 700 # "convnext_tiny",
684 pretrained=False, 701 # pretrained=False,
685 num_classes=3, 702 # num_classes=3,
686 drop_path_rate=0.0, 703 # drop_path_rate=0.0,
687 ) 704 # )
688 convnext.to(accelerator.device, dtype=weight_dtype) 705 # convnext.to(accelerator.device, dtype=weight_dtype)
689 convnext.requires_grad_(False) 706 # convnext.requires_grad_(False)
690 convnext.eval() 707 # convnext.eval()
691 708
692 if len(args.alias_tokens) != 0: 709 if len(args.alias_tokens) != 0:
693 alias_placeholder_tokens = args.alias_tokens[::2] 710 alias_placeholder_tokens = args.alias_tokens[::2]
@@ -822,6 +839,7 @@ def main():
822 noise_scheduler=noise_scheduler, 839 noise_scheduler=noise_scheduler,
823 dtype=weight_dtype, 840 dtype=weight_dtype,
824 seed=args.seed, 841 seed=args.seed,
842 compile_unet=args.compile_unet,
825 guidance_scale=args.guidance_scale, 843 guidance_scale=args.guidance_scale,
826 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, 844 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
827 no_val=args.valid_set_size == 0, 845 no_val=args.valid_set_size == 0,
@@ -831,6 +849,7 @@ def main():
831 milestone_checkpoints=not args.no_milestone_checkpoints, 849 milestone_checkpoints=not args.no_milestone_checkpoints,
832 global_step_offset=global_step_offset, 850 global_step_offset=global_step_offset,
833 offset_noise_strength=args.offset_noise_strength, 851 offset_noise_strength=args.offset_noise_strength,
852 input_pertubation=args.input_pertubation,
834 # -- 853 # --
835 use_emb_decay=args.use_emb_decay, 854 use_emb_decay=args.use_emb_decay,
836 emb_decay_target=args.emb_decay_target, 855 emb_decay_target=args.emb_decay_target,