From b5d3df18c3a56699a3658ad58a02d4494836972f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 21:46:15 +0200
Subject: Update

---
 training/strategy/dreambooth.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

(limited to 'training/strategy')

diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 88b441b..43fe838 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -130,6 +131,9 @@ def dreambooth_strategy_callbacks(
     unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
     text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+    unet_.forward = MethodType(unet_.forward, unet_)
+    text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
     with ema_context():
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder_,
@@ -185,6 +189,7 @@ def dreambooth_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    text_encoder_unfreeze_last_n_layers: int = 2,
     **kwargs
 ):
     (
@@ -198,6 +203,11 @@ def dreambooth_prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    for layer in text_encoder.text_model.encoder.layers[
+        : (-1 * text_encoder_unfreeze_last_n_layers)
+    ]:
+        layer.requires_grad_(False)
+
     text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
--
cgit v1.2.3-70-g09d2
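
An aside on the first hunk, for context rather than as part of the commit:
accelerator.unwrap_model(..., keep_fp32_wrapper=False) can leave `forward`
stored on the instance as a plain function, which no longer receives `self`
automatically when called through the attribute. Wrapping it in
types.MethodType, as the hunk does, turns it back into a bound method.
A minimal sketch of that mechanic (the Model/plain_forward names are
made up for illustration; the exact state after unwrapping depends on
accelerate's internals):

    from types import MethodType

    class Model:
        def forward(self, x):
            return x * 2

    m = Model()

    # Simulate an unwrap step that stores a plain function directly on the
    # instance: m.forward(3) would now fail, since `self` is no longer bound.
    def plain_forward(self, x):
        return x * 2

    m.forward = plain_forward

    # Rebinding restores the usual method-call behavior.
    m.forward = MethodType(m.forward, m)
    assert m.forward(3) == 6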
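
The dreambooth_prepare hunk freezes every text-encoder transformer block
except the last text_encoder_unfreeze_last_n_layers, along with the
embeddings, so only the top of the encoder is fine-tuned. A self-contained
sketch of the slicing, with a toy nn.ModuleList of 12 blocks standing in for
text_encoder.text_model.encoder.layers:

    import torch.nn as nn

    # Twelve stand-in blocks for the encoder's layer stack.
    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
    unfreeze_last_n = 2  # mirrors text_encoder_unfreeze_last_n_layers

    # layers[: -1 * n] selects everything except the last n blocks.
    for layer in layers[: -1 * unfreeze_last_n]:
        layer.requires_grad_(False)

    trainable = [
        i for i, layer in enumerate(layers)
        if all(p.requires_grad for p in layer.parameters())
    ]
    assert trainable == [10, 11]

One quirk of this slice worth noting: with a value of 0, `-1 * 0` is 0 and
layers[:0] is empty, so nothing would be frozen rather than everything.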