From 95adaea8b55d8e3755c035758bc649ae22548572 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 24 Mar 2023 10:53:16 +0100
Subject: Refactoring, fixed Lora training

---
 training/strategy/dreambooth.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

(limited to 'training/strategy/dreambooth.py')

diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 28fccff..9808027 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -74,6 +74,7 @@ def dreambooth_strategy_callbacks(
             power=ema_power,
             max_value=ema_max_decay,
         )
+        ema_unet.to(accelerator.device)
     else:
         ema_unet = None
 
@@ -86,14 +87,6 @@ def dreambooth_strategy_callbacks(
     def on_accum_model():
         return unet
 
-    def on_prepare():
-        unet.requires_grad_(True)
-        text_encoder.text_model.encoder.requires_grad_(True)
-        text_encoder.text_model.final_layer_norm.requires_grad_(True)
-
-        if ema_unet is not None:
-            ema_unet.to(accelerator.device)
-
     @contextmanager
     def on_train(epoch: int):
         tokenizer.train()
@@ -181,7 +174,6 @@ def dreambooth_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -203,7 +195,12 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.requires_grad_(False)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 dreambooth_strategy = TrainingStrategy(
-- 
cgit v1.2.3-54-g00ecf
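
Editor's note (not part of the patch): the commit moves the EMA device transfer to the point where the EMA model is created, drops the on_prepare callback, and freezes the text encoder embeddings inside dreambooth_prepare after accelerator.prepare(). The following is a minimal sketch of that reworked prepare step, assuming Hugging Face accelerate and a transformers CLIPTextModel; the function name prepare_models and its caller-side names are hypothetical, not part of the repository.

    import torch
    from accelerate import Accelerator
    from transformers import CLIPTextModel

    def prepare_models(accelerator: Accelerator, text_encoder: CLIPTextModel,
                       unet: torch.nn.Module, optimizer, train_dataloader,
                       val_dataloader, lr_scheduler):
        # accelerator.prepare() returns the wrapped objects in the order they
        # were passed in, so they are unpacked explicitly instead of relying
        # on tuple concatenation as the old one-liner did.
        (text_encoder, unet, optimizer, train_dataloader,
         val_dataloader, lr_scheduler) = accelerator.prepare(
            text_encoder, unet, optimizer, train_dataloader,
            val_dataloader, lr_scheduler
        )

        # Freeze the embedding weights here (instead of toggling requires_grad
        # in an on_prepare callback) so they stay fixed while the rest of the
        # text encoder can still receive gradients.
        text_encoder.text_model.embeddings.requires_grad_(False)

        # The trailing dict mirrors the strategy interface: extra state, empty here.
        return (text_encoder, unet, optimizer, train_dataloader,
                val_dataloader, lr_scheduler, {})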