summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/dreambooth.py2
-rw-r--r--training/strategy/lora.py2
-rw-r--r--training/strategy/ti.py2
4 files changed, 4 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py
index e14aeea..46d25f6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -644,11 +644,9 @@ def train(
644 min_snr_gamma: int = 5, 644 min_snr_gamma: int = 5,
645 **kwargs, 645 **kwargs,
646): 646):
647 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 647 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
648 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) 648 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
649 649
650 kwargs.update(extra)
651
652 vae.to(accelerator.device, dtype=dtype) 650 vae.to(accelerator.device, dtype=dtype)
653 vae.requires_grad_(False) 651 vae.requires_grad_(False)
654 vae.eval() 652 vae.eval()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 695174a..42624cd 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -198,7 +198,7 @@ def dreambooth_prepare(
198 198
199 text_encoder.text_model.embeddings.requires_grad_(False) 199 text_encoder.text_model.embeddings.requires_grad_(False)
200 200
201 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} 201 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
202 202
203 203
204dreambooth_strategy = TrainingStrategy( 204dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ae85401..73ec8f2 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -184,7 +184,7 @@ def lora_prepare(
184 184
185 text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) 185 text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True)
186 186
187 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} 187 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
188 188
189 189
190lora_strategy = TrainingStrategy( 190lora_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9cdc1bb..363c3f9 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -207,7 +207,7 @@ def textual_inversion_prepare(
207 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 207 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
208 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) 208 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
209 209
210 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} 210 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
211 211
212 212
213textual_inversion_strategy = TrainingStrategy( 213textual_inversion_strategy = TrainingStrategy(