diff options
author | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
commit | a1b8327085ddeab589be074d7e9df4291aba1210 (patch) | |
tree | 2f2016916d7a2f659268c3e375d55c59583c2b3b /training/strategy/dreambooth.py | |
parent | Fixed TI normalization order (diff) | |
download | textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2 textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip |
Update
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r-- | training/strategy/dreambooth.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0290327..e5e84c8 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks( | |||
88 | 88 | ||
89 | def on_prepare(): | 89 | def on_prepare(): |
90 | unet.requires_grad_(True) | 90 | unet.requires_grad_(True) |
91 | text_encoder.requires_grad_(True) | 91 | text_encoder.text_model.encoder.requires_grad_(True) |
92 | text_encoder.text_model.embeddings.requires_grad_(False) | 92 | text_encoder.text_model.final_layer_norm.requires_grad_(True) |
93 | 93 | ||
94 | if ema_unet is not None: | 94 | if ema_unet is not None: |
95 | ema_unet.to(accelerator.device) | 95 | ema_unet.to(accelerator.device) |
@@ -203,7 +203,7 @@ def dreambooth_prepare( | |||
203 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 203 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
204 | **kwargs | 204 | **kwargs |
205 | ): | 205 | ): |
206 | return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) | 206 | return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) |
207 | 207 | ||
208 | 208 | ||
209 | dreambooth_strategy = TrainingStrategy( | 209 | dreambooth_strategy = TrainingStrategy( |