summary refs log tree commit diff stats
path: root/training/strategy
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/dreambooth.py | 2
-rw-r--r--  training/strategy/lora.py       | 2
-rw-r--r--  training/strategy/ti.py         | 2
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 695174a..42624cd 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -198,7 +198,7 @@ def dreambooth_prepare(
 
     text_encoder.text_model.embeddings.requires_grad_(False)
 
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
 
 
 dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ae85401..73ec8f2 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -184,7 +184,7 @@ def lora_prepare(
 
     text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
 
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
 
 
 lora_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9cdc1bb..363c3f9 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -207,7 +207,7 @@ def textual_inversion_prepare(
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
     text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
 
 
 textual_inversion_strategy = TrainingStrategy(