diff options
author | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
commit | a1b8327085ddeab589be074d7e9df4291aba1210 (patch) | |
tree | 2f2016916d7a2f659268c3e375d55c59583c2b3b /training/strategy | |
parent | Fixed TI normalization order (diff) | |
download | textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2 textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip |
Update
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/dreambooth.py | 6 | ||||
-rw-r--r-- | training/strategy/ti.py | 2 |
2 files changed, 4 insertions, 4 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0290327..e5e84c8 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks( | |||
88 | 88 | ||
89 | def on_prepare(): | 89 | def on_prepare(): |
90 | unet.requires_grad_(True) | 90 | unet.requires_grad_(True) |
91 | text_encoder.requires_grad_(True) | 91 | text_encoder.text_model.encoder.requires_grad_(True) |
92 | text_encoder.text_model.embeddings.requires_grad_(False) | 92 | text_encoder.text_model.final_layer_norm.requires_grad_(True) |
93 | 93 | ||
94 | if ema_unet is not None: | 94 | if ema_unet is not None: |
95 | ema_unet.to(accelerator.device) | 95 | ema_unet.to(accelerator.device) |
@@ -203,7 +203,7 @@ def dreambooth_prepare( | |||
203 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 203 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
204 | **kwargs | 204 | **kwargs |
205 | ): | 205 | ): |
206 | return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) | 206 | return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) |
207 | 207 | ||
208 | 208 | ||
209 | dreambooth_strategy = TrainingStrategy( | 209 | dreambooth_strategy = TrainingStrategy( |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 732cd74..bd0d178 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -130,7 +130,7 @@ def textual_inversion_strategy_callbacks( | |||
130 | if lambda_ != 0: | 130 | if lambda_ != 0: |
131 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight | 131 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
132 | 132 | ||
133 | mask = torch.zeros(w.size(0), dtype=torch.bool) | 133 | mask = torch.zeros(w.shape[0], dtype=torch.bool) |
134 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True | 134 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True |
135 | mask[zero_ids] = False | 135 | mask[zero_ids] = False |
136 | 136 | ||