diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 3 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 8 | ||||
-rw-r--r-- | training/strategy/lora.py | 2 | ||||
-rw-r--r-- | training/strategy/ti.py | 4 |
4 files changed, 8 insertions, 9 deletions
diff --git a/training/functional.py b/training/functional.py index 78a2b10..41794ea 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader | |||
12 | 12 | ||
13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
14 | from transformers import CLIPTextModel | 14 | from transformers import CLIPTextModel |
15 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 15 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler |
16 | 16 | ||
17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
18 | from PIL import Image | 18 | from PIL import Image |
@@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
23 | from models.clip.util import get_extended_embeddings | 23 | from models.clip.util import get_extended_embeddings |
24 | from models.clip.tokenizer import MultiCLIPTokenizer | 24 | from models.clip.tokenizer import MultiCLIPTokenizer |
25 | from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler | ||
26 | from training.util import AverageMeter | 25 | from training.util import AverageMeter |
27 | 26 | ||
28 | 27 | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 8aaed3a..d697554 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -144,8 +144,8 @@ def dreambooth_strategy_callbacks( | |||
144 | 144 | ||
145 | print("Saving model...") | 145 | print("Saving model...") |
146 | 146 | ||
147 | unet_ = accelerator.unwrap_model(unet) | 147 | unet_ = accelerator.unwrap_model(unet, False) |
148 | text_encoder_ = accelerator.unwrap_model(text_encoder) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
149 | 149 | ||
150 | with ema_context(): | 150 | with ema_context(): |
151 | pipeline = VlpnStableDiffusion( | 151 | pipeline = VlpnStableDiffusion( |
@@ -167,8 +167,8 @@ def dreambooth_strategy_callbacks( | |||
167 | @torch.no_grad() | 167 | @torch.no_grad() |
168 | def on_sample(step): | 168 | def on_sample(step): |
169 | with ema_context(): | 169 | with ema_context(): |
170 | unet_ = accelerator.unwrap_model(unet) | 170 | unet_ = accelerator.unwrap_model(unet, False) |
171 | text_encoder_ = accelerator.unwrap_model(text_encoder) | 171 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
172 | 172 | ||
173 | orig_unet_dtype = unet_.dtype | 173 | orig_unet_dtype = unet_.dtype |
174 | orig_text_encoder_dtype = text_encoder_.dtype | 174 | orig_text_encoder_dtype = text_encoder_.dtype |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 4dd1100..ccec215 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -90,7 +90,7 @@ def lora_strategy_callbacks( | |||
90 | def on_checkpoint(step, postfix): | 90 | def on_checkpoint(step, postfix): |
91 | print(f"Saving checkpoint for step {step}...") | 91 | print(f"Saving checkpoint for step {step}...") |
92 | 92 | ||
93 | unet_ = accelerator.unwrap_model(unet) | 93 | unet_ = accelerator.unwrap_model(unet, False) |
94 | unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}") | 94 | unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}") |
95 | del unet_ | 95 | del unet_ |
96 | 96 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 0de3cb0..66d3129 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -144,8 +144,8 @@ def textual_inversion_strategy_callbacks( | |||
144 | @torch.no_grad() | 144 | @torch.no_grad() |
145 | def on_sample(step): | 145 | def on_sample(step): |
146 | with ema_context(): | 146 | with ema_context(): |
147 | unet_ = accelerator.unwrap_model(unet) | 147 | unet_ = accelerator.unwrap_model(unet, False) |
148 | text_encoder_ = accelerator.unwrap_model(text_encoder) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
149 | 149 | ||
150 | orig_unet_dtype = unet_.dtype | 150 | orig_unet_dtype = unet_.dtype |
151 | orig_text_encoder_dtype = text_encoder_.dtype | 151 | orig_text_encoder_dtype = text_encoder_.dtype |