From 8d2aa65402c829583e26cdf2c336b8d3057657d6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 5 May 2023 10:51:14 +0200
Subject: Update

---
 train_ti.py | 45 ++++++++++++++++++++++++++++++++-------------
 1 file changed, 32 insertions(+), 13 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index fce4a5e..ae73639 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -237,7 +237,13 @@ def parse_args():
         "--offset_noise_strength",
         type=float,
         default=0,
-        help="Perlin offset noise strength.",
+        help="Offset noise strength.",
+    )
+    parser.add_argument(
+        "--input_pertubation",
+        type=float,
+        default=0,
+        help="The scale of input perturbation. Recommended 0.1."
     )
     parser.add_argument(
         "--num_train_epochs",
@@ -406,6 +412,16 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--compile_unet",
+        action="store_true",
+        help="Compile UNet with Torch Dynamo.",
+    )
+    parser.add_argument(
+        "--use_xformers",
+        action="store_true",
+        help="Use xformers memory-efficient attention.",
+    )
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
@@ -671,23 +687,24 @@ def main():
         tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-    # unet = torch.compile(unet)
+
+    if args.use_xformers:
+        vae.set_use_memory_efficient_attention_xformers(True)
+        unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    convnext = create_model(
-        "convnext_tiny",
-        pretrained=False,
-        num_classes=3,
-        drop_path_rate=0.0,
-    )
-    convnext.to(accelerator.device, dtype=weight_dtype)
-    convnext.requires_grad_(False)
-    convnext.eval()
+    # convnext = create_model(
+    #     "convnext_tiny",
+    #     pretrained=False,
+    #     num_classes=3,
+    #     drop_path_rate=0.0,
+    # )
+    # convnext.to(accelerator.device, dtype=weight_dtype)
+    # convnext.requires_grad_(False)
+    # convnext.eval()
 
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
@@ -822,6 +839,7 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        compile_unet=args.compile_unet,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
@@ -831,6 +849,7 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
+        input_pertubation=args.input_pertubation,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
-- 
cgit v1.2.3-54-g00ecf
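
The training step that consumes the new input_pertubation value lives outside this
diff; the patch only parses the flag and passes it through to the trainer. For
context, below is a minimal sketch of how an input-perturbation flag of this kind
is typically applied when building the noised model input, following the
formulation from "Input Perturbation Reduces Exposure Bias in Diffusion Models"
used by diffusers' reference training scripts. The function name, signature, and
surrounding structure are illustrative assumptions, not this repository's actual
API.

    import torch

    def make_model_input(noise_scheduler, latents, timesteps, input_pertubation=0.0):
        # Noise the UNet is trained to predict; this stays the loss target.
        noise = torch.randn_like(latents)

        # Perturb only the noise used to construct the model input; the
        # prediction target remains the clean `noise`. The flag's help text
        # recommends a scale of 0.1.
        if input_pertubation != 0:
            input_noise = noise + input_pertubation * torch.randn_like(noise)
        else:
            input_noise = noise

        noisy_latents = noise_scheduler.add_noise(latents, input_noise, timesteps)
        # The caller runs the UNet on `noisy_latents` and computes the loss
        # (e.g. MSE) against the unperturbed `noise`.
        return noisy_latents, noise

Because the target is computed from the unperturbed noise, a nonzero scale trains
the model on slightly mismatched inputs, which is what counteracts exposure bias
at sampling time.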