Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/dreambooth.py | 6 +++---
-rw-r--r--  training/strategy/ti.py         | 2 +-
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0290327..e5e84c8 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks(
 
     def on_prepare():
         unet.requires_grad_(True)
-        text_encoder.requires_grad_(True)
-        text_encoder.text_model.embeddings.requires_grad_(False)
+        text_encoder.text_model.encoder.requires_grad_(True)
+        text_encoder.text_model.final_layer_norm.requires_grad_(True)
 
         if ema_unet is not None:
             ema_unet.to(accelerator.device)
@@ -203,7 +203,7 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({})
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
 
 
 dreambooth_strategy = TrainingStrategy(
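
For context on the first hunk: instead of enabling gradients on the whole text encoder and then freezing the embeddings, the new code unfreezes only the transformer blocks and the final layer norm, so the token embeddings stay frozen. A minimal sketch of that pattern, assuming a transformers CLIPTextModel (the checkpoint name is illustrative) and an encoder that starts fully frozen:

    from transformers import CLIPTextModel

    # Illustrative checkpoint; any CLIP text encoder with the same submodule
    # layout (text_model.embeddings / .encoder / .final_layer_norm) works.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    text_encoder.requires_grad_(False)                             # start fully frozen
    text_encoder.text_model.encoder.requires_grad_(True)           # unfreeze transformer blocks
    text_encoder.text_model.final_layer_norm.requires_grad_(True)  # unfreeze final layer norm

    # The token and position embeddings remain frozen:
    print(all(not p.requires_grad
              for p in text_encoder.text_model.embeddings.parameters()))  # True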
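The second hunk fixes an actual bug: ({}) is just an empty dict wrapped in parentheses, not a tuple, so adding it to the tuple returned by accelerator.prepare(...) raises a TypeError. The trailing comma in ({},) makes it a one-element tuple. A quick demonstration, with a plain tuple standing in for the prepared objects:

    prepared = ("text_encoder", "unet")   # stand-in for accelerator.prepare(...)

    try:
        prepared + ({})                   # ({}) is a dict, not a tuple
    except TypeError as err:
        print(err)                        # can only concatenate tuple (not "dict") to tuple

    print(prepared + ({},))               # ('text_encoder', 'unet', {})
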
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 732cd74..bd0d178 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -130,7 +130,7 @@ def textual_inversion_strategy_callbacks(
         if lambda_ != 0:
             w = text_encoder.text_model.embeddings.temp_token_embedding.weight
 
-            mask = torch.zeros(w.size(0), dtype=torch.bool)
+            mask = torch.zeros(w.shape[0], dtype=torch.bool)
             mask[text_encoder.text_model.embeddings.temp_token_ids] = True
             mask[zero_ids] = False
 
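
The w.size(0) to w.shape[0] change is cosmetic; in PyTorch both return the size of the first dimension. The surrounding lines build a boolean row mask over the temporary token embedding, keeping the temp token rows for the lambda_-weighted penalty while excluding any ids in zero_ids. A toy reproduction (sizes and ids are made up for illustration):

    import torch

    w = torch.randn(10, 4)                    # toy embedding matrix: 10 tokens, dim 4
    temp_token_ids = torch.tensor([3, 5, 7])  # hypothetical trainable token ids
    zero_ids = torch.tensor([5])              # hypothetical ids excluded from the penalty

    mask = torch.zeros(w.shape[0], dtype=torch.bool)  # w.shape[0] == w.size(0) == 10
    mask[temp_token_ids] = True
    mask[zero_ids] = False

    print(mask.nonzero(as_tuple=True)[0])     # tensor([3, 7])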