diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-21 21:46:15 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-21 21:46:15 +0200 |
| commit | b5d3df18c3a56699a3658ad58a02d4494836972f (patch) | |
| tree | 8a43468111eee827564bb5d1561d2d4910915c61 /training/strategy | |
| parent | Update (diff) | |
| download | textual-inversion-diff-b5d3df18c3a56699a3658ad58a02d4494836972f.tar.gz textual-inversion-diff-b5d3df18c3a56699a3658ad58a02d4494836972f.tar.bz2 textual-inversion-diff-b5d3df18c3a56699a3658ad58a02d4494836972f.zip | |
Update
Diffstat (limited to 'training/strategy')
| -rw-r--r-- | training/strategy/dreambooth.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 88b441b..43fe838 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -1,4 +1,5 @@ | |||
| 1 | from typing import Optional | 1 | from typing import Optional |
| 2 | from types import MethodType | ||
| 2 | from functools import partial | 3 | from functools import partial |
| 3 | from contextlib import contextmanager, nullcontext | 4 | from contextlib import contextmanager, nullcontext |
| 4 | from pathlib import Path | 5 | from pathlib import Path |
| @@ -130,6 +131,9 @@ def dreambooth_strategy_callbacks( | |||
| 130 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 131 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
| 131 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 132 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
| 132 | 133 | ||
| 134 | unet_.forward = MethodType(unet_.forward, unet_) | ||
| 135 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
| 136 | |||
| 133 | with ema_context(): | 137 | with ema_context(): |
| 134 | pipeline = VlpnStableDiffusion( | 138 | pipeline = VlpnStableDiffusion( |
| 135 | text_encoder=text_encoder_, | 139 | text_encoder=text_encoder_, |
| @@ -185,6 +189,7 @@ def dreambooth_prepare( | |||
| 185 | train_dataloader: DataLoader, | 189 | train_dataloader: DataLoader, |
| 186 | val_dataloader: Optional[DataLoader], | 190 | val_dataloader: Optional[DataLoader], |
| 187 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 191 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 192 | text_encoder_unfreeze_last_n_layers: int = 2, | ||
| 188 | **kwargs | 193 | **kwargs |
| 189 | ): | 194 | ): |
| 190 | ( | 195 | ( |
| @@ -198,6 +203,11 @@ def dreambooth_prepare( | |||
| 198 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 203 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 199 | ) | 204 | ) |
| 200 | 205 | ||
| 206 | for layer in text_encoder.text_model.encoder.layers[ | ||
| 207 | : (-1 * text_encoder_unfreeze_last_n_layers) | ||
| 208 | ]: | ||
| 209 | layer.requires_grad_(False) | ||
| 210 | |||
| 201 | text_encoder.text_model.embeddings.requires_grad_(False) | 211 | text_encoder.text_model.embeddings.requires_grad_(False) |
| 202 | 212 | ||
| 203 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 213 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
