From 3f922880475c2c0a5679987d4a9a43606e838566 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 22:26:43 +0100
Subject: Added Dreambooth strategy

---
 train_ti.py | 46 +++++++++++++++++++++++-----------------------
 1 file changed, 23 insertions(+), 23 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 77dec12..2497519 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -557,15 +557,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -624,6 +615,29 @@ def main():
         args.sample_steps
     )
 
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -642,20 +656,6 @@ def main():
             warmup_epochs=args.lr_warmup_epochs,
         )
 
-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
-        vae=vae,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
-        callbacks_fn=textual_inversion_strategy
-    )
-
     trainer(
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
--
cgit v1.2.3-54-g00ecf