From a1b8327085ddeab589be074d7e9df4291aba1210 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 1 Mar 2023 12:34:42 +0100 Subject: Update --- training/strategy/dreambooth.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'training/strategy/dreambooth.py') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0290327..e5e84c8 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks( def on_prepare(): unet.requires_grad_(True) - text_encoder.requires_grad_(True) - text_encoder.text_model.embeddings.requires_grad_(False) + text_encoder.text_model.encoder.requires_grad_(True) + text_encoder.text_model.final_layer_norm.requires_grad_(True) if ema_unet is not None: ema_unet.to(accelerator.device) @@ -203,7 +203,7 @@ def dreambooth_prepare( lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs ): - return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) + return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) dreambooth_strategy = TrainingStrategy( -- cgit v1.2.3-54-g00ecf