diff options
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 22 |
1 file changed, 12 insertions, 10 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 2038e34..10bc6d7 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -78,6 +78,7 @@ def textual_inversion_strategy_callbacks( | |||
78 | power=ema_power, | 78 | power=ema_power, |
79 | max_value=ema_max_decay, | 79 | max_value=ema_max_decay, |
80 | ) | 80 | ) |
81 | ema_embeddings.to(accelerator.device) | ||
81 | else: | 82 | else: |
82 | ema_embeddings = None | 83 | ema_embeddings = None |
83 | 84 | ||
@@ -92,15 +93,6 @@ def textual_inversion_strategy_callbacks( | |||
92 | def on_accum_model(): | 93 | def on_accum_model(): |
93 | return text_encoder.text_model.embeddings.temp_token_embedding | 94 | return text_encoder.text_model.embeddings.temp_token_embedding |
94 | 95 | ||
95 | def on_prepare(): | ||
96 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True) | ||
97 | |||
98 | if ema_embeddings is not None: | ||
99 | ema_embeddings.to(accelerator.device) | ||
100 | |||
101 | if gradient_checkpointing: | ||
102 | unet.train() | ||
103 | |||
104 | @contextmanager | 96 | @contextmanager |
105 | def on_train(epoch: int): | 97 | def on_train(epoch: int): |
106 | tokenizer.train() | 98 | tokenizer.train() |
@@ -177,7 +169,6 @@ def textual_inversion_strategy_callbacks( | |||
177 | torch.cuda.empty_cache() | 169 | torch.cuda.empty_cache() |
178 | 170 | ||
179 | return TrainingCallbacks( | 171 | return TrainingCallbacks( |
180 | on_prepare=on_prepare, | ||
181 | on_accum_model=on_accum_model, | 172 | on_accum_model=on_accum_model, |
182 | on_train=on_train, | 173 | on_train=on_train, |
183 | on_eval=on_eval, | 174 | on_eval=on_eval, |
@@ -197,6 +188,7 @@ def textual_inversion_prepare( | |||
197 | train_dataloader: DataLoader, | 188 | train_dataloader: DataLoader, |
198 | val_dataloader: Optional[DataLoader], | 189 | val_dataloader: Optional[DataLoader], |
199 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 190 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
191 | gradient_checkpointing: bool = False, | ||
200 | **kwargs | 192 | **kwargs |
201 | ): | 193 | ): |
202 | weight_dtype = torch.float32 | 194 | weight_dtype = torch.float32 |
@@ -207,7 +199,17 @@ def textual_inversion_prepare( | |||
207 | 199 | ||
208 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 200 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
209 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 201 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) |
202 | |||
210 | unet.to(accelerator.device, dtype=weight_dtype) | 203 | unet.to(accelerator.device, dtype=weight_dtype) |
204 | unet.requires_grad_(False) | ||
205 | unet.eval() | ||
206 | if gradient_checkpointing: | ||
207 | unet.train() | ||
208 | |||
209 | text_encoder.text_model.encoder.requires_grad_(False) | ||
210 | text_encoder.text_model.final_layer_norm.requires_grad_(False) | ||
211 | text_encoder.eval() | ||
212 | |||
211 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} | 213 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} |
212 | 214 | ||
213 | 215 | ||