Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/dreambooth.py | 10 ++++++++++
 1 file changed, 10 insertions(+), 0 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 88b441b..43fe838 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -130,6 +131,9 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        unet_.forward = MethodType(unet_.forward, unet_)
+        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
@@ -185,6 +189,7 @@ def dreambooth_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    text_encoder_unfreeze_last_n_layers: int = 2,
     **kwargs
 ):
     (
@@ -198,6 +203,11 @@ def dreambooth_prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    for layer in text_encoder.text_model.encoder.layers[
+        : (-1 * text_encoder_unfreeze_last_n_layers)
+    ]:
+        layer.requires_grad_(False)
+
     text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
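
A note on the MethodType lines in the second hunk: after
accelerator.unwrap_model(..., keep_fp32_wrapper=False) strips accelerate's
mixed-precision wrapper, the commit re-binds forward onto the unwrapped unet
and text encoder, apparently so that forward behaves as an ordinary bound
method of the unwrapped instance again. A minimal sketch of what
types.MethodType itself does; the Greeter class and greet function are
hypothetical, purely for illustration:

    from types import MethodType

    class Greeter:
        def __init__(self, name):
            self.name = name

    # A plain function that expects the instance as its first argument.
    def greet(self):
        return f"hello, {self.name}"

    g = Greeter("world")
    # Bind greet to g: g.greet() now passes g implicitly as self,
    # exactly like a method defined on the class.
    g.greet = MethodType(greet, g)
    print(g.greet())  # -> hello, world

In the diff the same binding is applied to the existing forward attribute
rather than a free function, pinning it to the unwrapped instance.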
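
A note on the new text_encoder_unfreeze_last_n_layers parameter:
dreambooth_prepare now freezes every text-encoder layer except the last N
(default 2), alongside the already-frozen embeddings, so only the top of the
text encoder is fine-tuned. A sketch of the same freezing pattern against a
randomly initialized CLIP text encoder from transformers;
freeze_all_but_last_n is a hypothetical helper name, and the explicit cutoff
is an assumed tweak to sidestep the [:-0] edge case (with N == 0 the diff's
slice would freeze nothing rather than everything):

    from transformers import CLIPTextConfig, CLIPTextModel

    def freeze_all_but_last_n(text_encoder: CLIPTextModel, unfreeze_last_n: int = 2):
        # Embeddings never receive gradients, matching the existing
        # line in dreambooth_prepare.
        text_encoder.text_model.embeddings.requires_grad_(False)
        layers = text_encoder.text_model.encoder.layers
        # Explicit cutoff instead of layers[:-n]: with n == 0 a [:-0]
        # slice is empty and would freeze nothing.
        cutoff = len(layers) - unfreeze_last_n
        for layer in layers[:cutoff]:
            layer.requires_grad_(False)

    # Randomly initialized 12-layer model; no weights are downloaded.
    enc = CLIPTextModel(CLIPTextConfig())
    freeze_all_but_last_n(enc, unfreeze_last_n=2)
    trainable = sum(p.numel() for p in enc.parameters() if p.requires_grad)
    print(f"trainable parameters: {trainable}")

After the call, only the last two encoder layers (plus the final layer norm,
which neither the diff nor the sketch touches) keep requires_grad=True.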