From 036e33fbde6bad7c48bb6f6b3d695b7908535c64 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 3 Nov 2022 17:56:08 +0100
Subject: Update

---
 textual_inversion.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/textual_inversion.py b/textual_inversion.py
index 115f3aa..578c054 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -162,10 +163,10 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine_with_restarts",
+        default="one_cycle",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
         ),
     )
     parser.add_argument(
@@ -535,6 +536,8 @@ def main():
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
 
@@ -693,7 +696,12 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    if args.lr_scheduler == "cosine_with_restarts":
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
--
cgit v1.2.3-54-g00ecf
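
A note on the new unet.set_use_memory_efficient_attention_xformers(True) call: diffusers raises at this point if the xformers package is missing or was built against an incompatible torch/CUDA combination, and the patch enables it unconditionally. A minimal sketch of a guarded variant follows; it is not part of the patch, and it assumes the context of main(), where unet is already in scope:

# Hedged sketch, not in the patch: tolerate a missing or incompatible
# xformers install instead of failing hard when enabling
# memory-efficient attention on the UNet.
try:
    unet.set_use_memory_efficient_attention_xformers(True)
except Exception as e:
    print(f"xformers unavailable, falling back to default attention: {e}")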
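
The imported training.optimization.get_one_cycle_schedule is not shown in this patch; only its call site is. A minimal sketch of what such a helper could look like, assuming it mirrors the diffusers get_*_schedule helpers by returning a torch LambdaLR whose multiplier scales the optimizer's base learning rate. The warmup_fraction parameter and the linear up/down shape are illustrative assumptions, not the repository's actual implementation:

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def get_one_cycle_schedule(
    optimizer: Optimizer,
    num_training_steps: int,
    warmup_fraction: float = 0.3,  # assumed value, not from the repo
) -> LambdaLR:
    warmup_steps = int(num_training_steps * warmup_fraction)

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # Linear ramp from 0 up to the base learning rate.
            return current_step / max(1, warmup_steps)
        # Linear decay from the base learning rate back down to 0.
        return max(
            0.0,
            (num_training_steps - current_step)
            / max(1, num_training_steps - warmup_steps),
        )

    return LambdaLR(optimizer, lr_lambda)

Note that, unlike the cosine_with_restarts branch, the one_cycle call site passes no num_warmup_steps: the warmup phase is baked into the schedule itself, which is consistent with a one-cycle policy where ramp-up and decay form a single curve.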