diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-17 14:53:25 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-17 14:53:25 +0100 |
| commit | 842f26654bbe7dfd2f45df1fd2660d3f902af8cc (patch) | |
| tree | 3e7cd2dea37f025f9aa2755a893efd29195c7396 /training/strategy | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-842f26654bbe7dfd2f45df1fd2660d3f902af8cc.tar.gz textual-inversion-diff-842f26654bbe7dfd2f45df1fd2660d3f902af8cc.tar.bz2 textual-inversion-diff-842f26654bbe7dfd2f45df1fd2660d3f902af8cc.zip | |
Remove xformers, switch to Pytorch Nightly
Diffstat (limited to 'training/strategy')
| -rw-r--r-- | training/strategy/dreambooth.py | 8 | ||||
| -rw-r--r-- | training/strategy/lora.py | 2 | ||||
| -rw-r--r-- | training/strategy/ti.py | 4 |
3 files changed, 7 insertions, 7 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 8aaed3a..d697554 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -144,8 +144,8 @@ def dreambooth_strategy_callbacks( | |||
| 144 | 144 | ||
| 145 | print("Saving model...") | 145 | print("Saving model...") |
| 146 | 146 | ||
| 147 | unet_ = accelerator.unwrap_model(unet) | 147 | unet_ = accelerator.unwrap_model(unet, False) |
| 148 | text_encoder_ = accelerator.unwrap_model(text_encoder) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
| 149 | 149 | ||
| 150 | with ema_context(): | 150 | with ema_context(): |
| 151 | pipeline = VlpnStableDiffusion( | 151 | pipeline = VlpnStableDiffusion( |
| @@ -167,8 +167,8 @@ def dreambooth_strategy_callbacks( | |||
| 167 | @torch.no_grad() | 167 | @torch.no_grad() |
| 168 | def on_sample(step): | 168 | def on_sample(step): |
| 169 | with ema_context(): | 169 | with ema_context(): |
| 170 | unet_ = accelerator.unwrap_model(unet) | 170 | unet_ = accelerator.unwrap_model(unet, False) |
| 171 | text_encoder_ = accelerator.unwrap_model(text_encoder) | 171 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
| 172 | 172 | ||
| 173 | orig_unet_dtype = unet_.dtype | 173 | orig_unet_dtype = unet_.dtype |
| 174 | orig_text_encoder_dtype = text_encoder_.dtype | 174 | orig_text_encoder_dtype = text_encoder_.dtype |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 4dd1100..ccec215 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -90,7 +90,7 @@ def lora_strategy_callbacks( | |||
| 90 | def on_checkpoint(step, postfix): | 90 | def on_checkpoint(step, postfix): |
| 91 | print(f"Saving checkpoint for step {step}...") | 91 | print(f"Saving checkpoint for step {step}...") |
| 92 | 92 | ||
| 93 | unet_ = accelerator.unwrap_model(unet) | 93 | unet_ = accelerator.unwrap_model(unet, False) |
| 94 | unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}") | 94 | unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}") |
| 95 | del unet_ | 95 | del unet_ |
| 96 | 96 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 0de3cb0..66d3129 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -144,8 +144,8 @@ def textual_inversion_strategy_callbacks( | |||
| 144 | @torch.no_grad() | 144 | @torch.no_grad() |
| 145 | def on_sample(step): | 145 | def on_sample(step): |
| 146 | with ema_context(): | 146 | with ema_context(): |
| 147 | unet_ = accelerator.unwrap_model(unet) | 147 | unet_ = accelerator.unwrap_model(unet, False) |
| 148 | text_encoder_ = accelerator.unwrap_model(text_encoder) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
| 149 | 149 | ||
| 150 | orig_unet_dtype = unet_.dtype | 150 | orig_unet_dtype = unet_.dtype |
| 151 | orig_text_encoder_dtype = text_encoder_.dtype | 151 | orig_text_encoder_dtype = text_encoder_.dtype |
