Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--  training/strategy/ti.py  22
1 file changed, 12 insertions(+), 10 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 2038e34..10bc6d7 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -78,6 +78,7 @@ def textual_inversion_strategy_callbacks(
             power=ema_power,
             max_value=ema_max_decay,
         )
+        ema_embeddings.to(accelerator.device)
     else:
         ema_embeddings = None
 
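
Note: the hunk above moves the EMA device transfer to the point where the EMA model is constructed. A minimal sketch of the resulting setup; the EMAModel helper and the inv_gamma argument name are assumptions beyond what the hunk shows:

# Sketch only: EMAModel stands in for the repository's own EMA helper and
# inv_gamma is an assumed argument; power/max_value appear in the hunk above.
ema_embeddings = EMAModel(
    text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
    inv_gamma=ema_inv_gamma,
    power=ema_power,
    max_value=ema_max_decay,
)
# The device move now happens eagerly at strategy construction instead of in
# the on_prepare() callback removed by the hunks below.
ema_embeddings.to(accelerator.device)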
@@ -92,15 +93,6 @@ def textual_inversion_strategy_callbacks(
     def on_accum_model():
         return text_encoder.text_model.embeddings.temp_token_embedding
 
-    def on_prepare():
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
-
-        if ema_embeddings is not None:
-            ema_embeddings.to(accelerator.device)
-
-        if gradient_checkpointing:
-            unet.train()
-
     @contextmanager
     def on_train(epoch: int):
         tokenizer.train()
@@ -177,7 +169,6 @@ def textual_inversion_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
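
With the EMA transfer handled at construction and the remaining setup moved into textual_inversion_prepare (next hunks), on_prepare has no work left, so both the callback and its registration are removed. The one subtle rule it carried is the gradient-checkpointing interplay; a sketch of that invariant, using the names from the final hunk:

# The UNet's weights stay frozen either way, but diffusers-style UNet blocks
# only take the activation-checkpointing path when self.training is True, so
# the frozen model must still be switched to train() mode when checkpointing
# is enabled (saving memory on the gradient path back to the embeddings).
unet.requires_grad_(False)
unet.eval()
if gradient_checkpointing:
    unet.train()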
@@ -197,6 +188,7 @@ def textual_inversion_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    gradient_checkpointing: bool = False,
     **kwargs
 ):
     weight_dtype = torch.float32
@@ -207,7 +199,17 @@ def textual_inversion_prepare(
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
     unet.to(accelerator.device, dtype=weight_dtype)
+    unet.requires_grad_(False)
+    unet.eval()
+    if gradient_checkpointing:
+        unet.train()
+
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.eval()
+
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
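After this change, device placement and freezing are concentrated in textual_inversion_prepare. A hypothetical call site under the new signature; the parameters ahead of train_dataloader (accelerator, text_encoder, unet, optimizer) are inferred from the function body, not shown by the diff:

# Hypothetical driver code; argument names before train_dataloader are an
# assumption based on what textual_inversion_prepare's body references.
text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, _ = \
    textual_inversion_prepare(
        accelerator=accelerator,
        text_encoder=text_encoder,
        unet=unet,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        lr_scheduler=lr_scheduler,
        gradient_checkpointing=True,  # new flag: frozen UNet still runs in train() mode
    )
# Afterwards only the temp token embedding is meant to train: the text
# encoder's transformer stack and final layer norm have requires_grad=False.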