summary | refs | log | tree | commit | diff | stats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--  training/functional.py           3
-rw-r--r--  training/strategy/dreambooth.py  8
-rw-r--r--  training/strategy/lora.py        2
-rw-r--r--  training/strategy/ti.py          4
4 files changed, 8 insertions, 9 deletions
diff --git a/training/functional.py b/training/functional.py
index 78a2b10..41794ea 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
12 12
13from accelerate import Accelerator 13from accelerate import Accelerator
14from transformers import CLIPTextModel 14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler
16 16
17from tqdm.auto import tqdm 17from tqdm.auto import tqdm
18from PIL import Image 18from PIL import Image
@@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
22from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 22from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
23from models.clip.util import get_extended_embeddings 23from models.clip.util import get_extended_embeddings
24from models.clip.tokenizer import MultiCLIPTokenizer 24from models.clip.tokenizer import MultiCLIPTokenizer
25from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
26from training.util import AverageMeter 25from training.util import AverageMeter
27 26
28 27
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 8aaed3a..d697554 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -144,8 +144,8 @@ def dreambooth_strategy_callbacks(
144 144
145 print("Saving model...") 145 print("Saving model...")
146 146
147 unet_ = accelerator.unwrap_model(unet) 147 unet_ = accelerator.unwrap_model(unet, False)
148 text_encoder_ = accelerator.unwrap_model(text_encoder) 148 text_encoder_ = accelerator.unwrap_model(text_encoder, False)
149 149
150 with ema_context(): 150 with ema_context():
151 pipeline = VlpnStableDiffusion( 151 pipeline = VlpnStableDiffusion(
@@ -167,8 +167,8 @@ def dreambooth_strategy_callbacks(
167 @torch.no_grad() 167 @torch.no_grad()
168 def on_sample(step): 168 def on_sample(step):
169 with ema_context(): 169 with ema_context():
170 unet_ = accelerator.unwrap_model(unet) 170 unet_ = accelerator.unwrap_model(unet, False)
171 text_encoder_ = accelerator.unwrap_model(text_encoder) 171 text_encoder_ = accelerator.unwrap_model(text_encoder, False)
172 172
173 orig_unet_dtype = unet_.dtype 173 orig_unet_dtype = unet_.dtype
174 orig_text_encoder_dtype = text_encoder_.dtype 174 orig_text_encoder_dtype = text_encoder_.dtype
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 4dd1100..ccec215 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -90,7 +90,7 @@ def lora_strategy_callbacks(
90 def on_checkpoint(step, postfix): 90 def on_checkpoint(step, postfix):
91 print(f"Saving checkpoint for step {step}...") 91 print(f"Saving checkpoint for step {step}...")
92 92
93 unet_ = accelerator.unwrap_model(unet) 93 unet_ = accelerator.unwrap_model(unet, False)
94 unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}") 94 unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
95 del unet_ 95 del unet_
96 96
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 0de3cb0..66d3129 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -144,8 +144,8 @@ def textual_inversion_strategy_callbacks(
144 @torch.no_grad() 144 @torch.no_grad()
145 def on_sample(step): 145 def on_sample(step):
146 with ema_context(): 146 with ema_context():
147 unet_ = accelerator.unwrap_model(unet) 147 unet_ = accelerator.unwrap_model(unet, False)
148 text_encoder_ = accelerator.unwrap_model(text_encoder) 148 text_encoder_ = accelerator.unwrap_model(text_encoder, False)
149 149
150 orig_unet_dtype = unet_.dtype 150 orig_unet_dtype = unet_.dtype
151 orig_text_encoder_dtype = text_encoder_.dtype 151 orig_text_encoder_dtype = text_encoder_.dtype