From 036e33fbde6bad7c48bb6f6b3d695b7908535c64 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 3 Nov 2022 17:56:08 +0100
Subject: Update

---
 dreambooth.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index c0caf03..8c4bf50 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -26,6 +26,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -178,10 +179,10 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine_with_restarts",
+        default="one_cycle",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
         ),
     )
     parser.add_argument(
@@ -585,6 +586,8 @@ def main():
         device=accelerator.device
     )
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
@@ -784,7 +787,12 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    if args.lr_scheduler == "cosine_with_restarts":
+    if args.lr_scheduler == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
+    elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
--
cgit v1.2.3-54-g00ecf
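
Note: the patch imports get_one_cycle_schedule from training/optimization.py but does not include that module, so its exact behaviour is not shown here. The following is a minimal, hypothetical sketch of what such a one-cycle helper could look like, assuming a linear warm-up followed by a cosine anneal to zero; the warmup_fraction parameter is made up for illustration and only the call signature used in the diff (optimizer, num_training_steps) is taken from the patch. It is built on torch.optim.lr_scheduler.LambdaLR.

    # Hypothetical sketch only: the real get_one_cycle_schedule lives in
    # training/optimization.py and is not part of this patch.
    import math

    from torch.optim import Optimizer
    from torch.optim.lr_scheduler import LambdaLR


    def get_one_cycle_schedule(
        optimizer: Optimizer,
        num_training_steps: int,
        warmup_fraction: float = 0.3,  # assumed parameter, not from the patch
        last_epoch: int = -1,
    ) -> LambdaLR:
        """Linearly ramp the LR up for the first part of training, then anneal
        it back to zero with a cosine curve (assumed shape of the cycle)."""
        warmup_steps = int(num_training_steps * warmup_fraction)

        def lr_lambda(current_step: int) -> float:
            if current_step < warmup_steps:
                # Warm-up phase: scale from 0 up to the optimizer's base LR.
                return current_step / max(1, warmup_steps)
            # Annealing phase: cosine decay from the base LR down to 0.
            progress = (current_step - warmup_steps) / max(1, num_training_steps - warmup_steps)
            return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

        return LambdaLR(optimizer, lr_lambda, last_epoch)

A helper with this shape would be consumed exactly as the new branch in main() shows: lr_scheduler = get_one_cycle_schedule(optimizer=optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps).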