From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 7 Feb 2023 20:44:43 +0100 Subject: Add Lora --- training/strategy/dreambooth.py | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) (limited to 'training/strategy/dreambooth.py') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e88bf90..b4c77f3 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -61,14 +61,11 @@ def dreambooth_strategy_callbacks( save_samples_ = partial( save_samples, accelerator=accelerator, - unet=unet, - text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, sample_scheduler=sample_scheduler, train_dataloader=train_dataloader, val_dataloader=val_dataloader, - dtype=weight_dtype, output_dir=sample_output_dir, seed=seed, batch_size=sample_batch_size, @@ -94,7 +91,7 @@ def dreambooth_strategy_callbacks( else: return nullcontext() - def on_model(): + def on_accum_model(): return unet def on_prepare(): @@ -172,11 +169,29 @@ def dreambooth_strategy_callbacks( @torch.no_grad() def on_sample(step): with ema_context(): - save_samples_(step=step) + unet_ = accelerator.unwrap_model(unet) + text_encoder_ = accelerator.unwrap_model(text_encoder) + + orig_unet_dtype = unet_.dtype + orig_text_encoder_dtype = text_encoder_.dtype + + unet_.to(dtype=weight_dtype) + text_encoder_.to(dtype=weight_dtype) + + save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) + + unet_.to(dtype=orig_unet_dtype) + text_encoder_.to(dtype=orig_text_encoder_dtype) + + del unet_ + del text_encoder_ + + if torch.cuda.is_available(): + torch.cuda.empty_cache() return TrainingCallbacks( on_prepare=on_prepare, - on_model=on_model, + on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, @@ -191,9 +206,13 @@ def dreambooth_prepare( accelerator: Accelerator, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, - *args + optimizer: torch.optim.Optimizer, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + **kwargs ): - return accelerator.prepare(text_encoder, unet, *args) + return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) dreambooth_strategy = TrainingStrategy( -- cgit v1.2.3-54-g00ecf