summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py35
-rw-r--r--training/strategy/dreambooth.py7
-rw-r--r--training/strategy/lora.py6
-rw-r--r--training/strategy/ti.py3
4 files changed, 22 insertions, 29 deletions
diff --git a/training/functional.py b/training/functional.py
index be39776..ed8ae3a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -168,8 +168,7 @@ def save_samples(
168 image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] 168 image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
169 image_grid.save(file_path, quality=85) 169 image_grid.save(file_path, quality=85)
170 170
171 del generator 171 del generator, pipeline
172 del pipeline
173 172
174 if torch.cuda.is_available(): 173 if torch.cuda.is_available():
175 torch.cuda.empty_cache() 174 torch.cuda.empty_cache()
@@ -398,31 +397,32 @@ def loss_step(
398 else: 397 else:
399 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 398 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
400 399
401 if disc is None: 400 acc = (model_pred == target).float().mean()
402 if guidance_scale == 0 and prior_loss_weight != 0:
403 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
404 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
405 target, target_prior = torch.chunk(target, 2, dim=0)
406 401
407 # Compute instance loss 402 if guidance_scale == 0 and prior_loss_weight != 0:
408 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 403 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
404 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
405 target, target_prior = torch.chunk(target, 2, dim=0)
409 406
410 # Compute prior loss 407 # Compute instance loss
411 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") 408 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
412 409
413 # Add the prior loss to the instance loss. 410 # Compute prior loss
414 loss = loss + prior_loss_weight * prior_loss 411 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
415 else:
416 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
417 412
418 loss = loss.mean([1, 2, 3]) 413 # Add the prior loss to the instance loss.
414 loss = loss + prior_loss_weight * prior_loss
419 else: 415 else:
416 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
417
418 loss = loss.mean([1, 2, 3])
419
420 if disc is not None:
420 rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) 421 rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
421 rec_latent /= vae.config.scaling_factor 422 rec_latent /= vae.config.scaling_factor
422 rec_latent = rec_latent.to(dtype=vae.dtype) 423 rec_latent = rec_latent.to(dtype=vae.dtype)
423 rec = vae.decode(rec_latent).sample 424 rec = vae.decode(rec_latent).sample
424 loss = 1 - disc.get_score(rec) 425 loss = 1 - disc.get_score(rec)
425 del rec_latent, rec
426 426
427 if min_snr_gamma != 0: 427 if min_snr_gamma != 0:
428 snr = compute_snr(timesteps, noise_scheduler) 428 snr = compute_snr(timesteps, noise_scheduler)
@@ -432,7 +432,6 @@ def loss_step(
432 loss *= mse_loss_weights 432 loss *= mse_loss_weights
433 433
434 loss = loss.mean() 434 loss = loss.mean()
435 acc = (model_pred == target).float().mean()
436 435
437 return loss, acc, bsz 436 return loss, acc, bsz
438 437
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index fa51bc7..4ae28b7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -142,9 +142,7 @@ def dreambooth_strategy_callbacks(
142 ) 142 )
143 pipeline.save_pretrained(checkpoint_output_dir) 143 pipeline.save_pretrained(checkpoint_output_dir)
144 144
145 del unet_ 145 del unet_, text_encoder_, pipeline
146 del text_encoder_
147 del pipeline
148 146
149 if torch.cuda.is_available(): 147 if torch.cuda.is_available():
150 torch.cuda.empty_cache() 148 torch.cuda.empty_cache()
@@ -165,8 +163,7 @@ def dreambooth_strategy_callbacks(
165 unet_.to(dtype=orig_unet_dtype) 163 unet_.to(dtype=orig_unet_dtype)
166 text_encoder_.to(dtype=orig_text_encoder_dtype) 164 text_encoder_.to(dtype=orig_text_encoder_dtype)
167 165
168 del unet_ 166 del unet_, text_encoder_
169 del text_encoder_
170 167
171 if torch.cuda.is_available(): 168 if torch.cuda.is_available():
172 torch.cuda.empty_cache() 169 torch.cuda.empty_cache()
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 73ec8f2..1517ee8 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -140,8 +140,7 @@ def lora_strategy_callbacks(
140 with open(checkpoint_output_dir / "lora_config.json", "w") as f: 140 with open(checkpoint_output_dir / "lora_config.json", "w") as f:
141 json.dump(lora_config, f) 141 json.dump(lora_config, f)
142 142
143 del unet_ 143 del unet_, text_encoder_
144 del text_encoder_
145 144
146 if torch.cuda.is_available(): 145 if torch.cuda.is_available():
147 torch.cuda.empty_cache() 146 torch.cuda.empty_cache()
@@ -153,8 +152,7 @@ def lora_strategy_callbacks(
153 152
154 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) 153 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
155 154
156 del unet_ 155 del unet_, text_encoder_
157 del text_encoder_
158 156
159 if torch.cuda.is_available(): 157 if torch.cuda.is_available():
160 torch.cuda.empty_cache() 158 torch.cuda.empty_cache()
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 08af89d..ca7cc3d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -158,8 +158,7 @@ def textual_inversion_strategy_callbacks(
158 unet_.to(dtype=orig_unet_dtype) 158 unet_.to(dtype=orig_unet_dtype)
159 text_encoder_.to(dtype=orig_text_encoder_dtype) 159 text_encoder_.to(dtype=orig_text_encoder_dtype)
160 160
161 del unet_ 161 del unet_, text_encoder_
162 del text_encoder_
163 162
164 if torch.cuda.is_available(): 163 if torch.cuda.is_available():
165 torch.cuda.empty_cache() 164 torch.cuda.empty_cache()