Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--  training/strategy/ti.py | 22 ++++++++++++----------
1 file changed, 12 insertions(+), 10 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 2038e34..10bc6d7 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -78,6 +78,7 @@ def textual_inversion_strategy_callbacks(
             power=ema_power,
             max_value=ema_max_decay,
         )
+        ema_embeddings.to(accelerator.device)
     else:
         ema_embeddings = None
 
@@ -92,15 +93,6 @@ def textual_inversion_strategy_callbacks(
     def on_accum_model():
         return text_encoder.text_model.embeddings.temp_token_embedding
 
-    def on_prepare():
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
-
-        if ema_embeddings is not None:
-            ema_embeddings.to(accelerator.device)
-
-        if gradient_checkpointing:
-            unet.train()
-
     @contextmanager
     def on_train(epoch: int):
         tokenizer.train()
@@ -177,7 +169,6 @@ def textual_inversion_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -197,6 +188,7 @@ def textual_inversion_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    gradient_checkpointing: bool = False,
     **kwargs
 ):
     weight_dtype = torch.float32
@@ -207,7 +199,17 @@ def textual_inversion_prepare(
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
     unet.to(accelerator.device, dtype=weight_dtype)
+    unet.requires_grad_(False)
+    unet.eval()
+    if gradient_checkpointing:
+        unet.train()
+
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.eval()
+
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
