Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--  training/strategy/ti.py | 38 +++++++++++++++++++++++++++++---------
 1 file changed, 29 insertions(+), 9 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 14bdafd..d306f18 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -59,14 +59,11 @@ def textual_inversion_strategy_callbacks(
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -94,7 +91,7 @@ def textual_inversion_strategy_callbacks(
         else:
             return nullcontext()
 
-    def on_model():
+    def on_accum_model():
         return text_encoder.text_model.embeddings.temp_token_embedding
 
     def on_prepare():
@@ -149,11 +146,29 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_sample(step):
         with ema_context():
-            save_samples_(step=step)
+            unet_ = accelerator.unwrap_model(unet)
+            text_encoder_ = accelerator.unwrap_model(text_encoder)
+
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype
+
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)
+
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)
+
+            del unet_
+            del text_encoder_
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
-        on_model=on_model,
+        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
@@ -168,7 +183,11 @@ def textual_inversion_prepare(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
     unet: UNet2DConditionModel,
-    *args
+    optimizer: torch.optim.Optimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    **kwargs
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -176,9 +195,10 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    prepped = accelerator.prepare(text_encoder, *args)
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
     unet.to(accelerator.device, dtype=weight_dtype)
-    return (prepped[0], unet) + prepped[1:]
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 textual_inversion_strategy = TrainingStrategy(
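
The core of the on_sample change is a dtype round-trip: the models are unwrapped, temporarily cast to the sampling precision, handed to save_samples_, then cast back so training resumes in the original precision, with the CUDA cache flushed afterwards. Below is a minimal, self-contained sketch of that pattern on a toy module; the model and the sample_with_temporary_dtype helper are hypothetical stand-ins, not the repo's objects, and the try/finally is extra hardening the diff itself does not use.

import torch
import torch.nn as nn

def sample_with_temporary_dtype(model: nn.Module, sample_dtype: torch.dtype) -> torch.Tensor:
    # Remember the training dtype so it can be restored afterwards.
    orig_dtype = next(model.parameters()).dtype

    model.to(dtype=sample_dtype)
    try:
        with torch.no_grad():
            out = model(torch.randn(1, 4, dtype=sample_dtype))
    finally:
        # Cast back even if sampling fails, so the optimizer state
        # keeps matching the parameters it was built for.
        model.to(dtype=orig_dtype)

    # Drop cached allocations left over from the temporary cast.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return out

print(sample_with_temporary_dtype(nn.Linear(4, 4), torch.bfloat16).dtype)  # torch.bfloat16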
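
The textual_inversion_prepare change trades the opaque *args/prepped tuple for named parameters, relying on accelerator.prepare returning its inputs in the order they were passed; the trailing {} in the return is an empty slot for strategy-specific extras. A small stand-alone illustration of that prepare contract, with placeholder model and data (assumes the accelerate package):

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
train_loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# prepare() wraps each object for the current device/precision setup and
# returns them in the same order, which is why positional unpacking is safe.
model, optimizer, train_loader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_loader, lr_scheduler
)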