summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py83
1 files changed, 59 insertions, 24 deletions
diff --git a/training/functional.py b/training/functional.py
index 2f7f837..be39776 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
23from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 23from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
24from models.clip.util import get_extended_embeddings 24from models.clip.util import get_extended_embeddings
25from models.clip.tokenizer import MultiCLIPTokenizer 25from models.clip.tokenizer import MultiCLIPTokenizer
26from models.convnext.discriminator import ConvNeXtDiscriminator
26from training.util import AverageMeter 27from training.util import AverageMeter
27from util.slerp import slerp 28from util.slerp import slerp
28 29
@@ -160,7 +161,8 @@ def save_samples(
160 161
161 for tracker in accelerator.trackers: 162 for tracker in accelerator.trackers:
162 if tracker.name == "tensorboard": 163 if tracker.name == "tensorboard":
163 tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") 164 # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
165 pass
164 166
165 image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) 167 image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
166 image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] 168 image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
@@ -275,16 +277,38 @@ def compute_snr(timesteps, noise_scheduler):
275 return snr 277 return snr
276 278
277 279
278def make_solid_image(color: float, shape, vae, dtype, device, generator): 280def get_original(
279 img = torch.tensor( 281 noise_scheduler,
280 [[[[color]]]], 282 model_output,
281 dtype=dtype, 283 sample: torch.FloatTensor,
282 device=device 284 timesteps: torch.IntTensor
283 ).expand(1, *shape) 285):
284 img = img * 2 - 1 286 alphas_cumprod = noise_scheduler.alphas_cumprod
285 img = vae.encode(img).latent_dist.sample(generator=generator) 287 sqrt_alphas_cumprod = alphas_cumprod**0.5
286 img *= vae.config.scaling_factor 288 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
287 return img 289
290 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
291 while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
292 sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
293 alpha = sqrt_alphas_cumprod.expand(sample.shape)
294
295 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
296 while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
297 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
298 sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
299
300 if noise_scheduler.config.prediction_type == "epsilon":
301 pred_original_sample = (sample - sigma * model_output) / alpha
302 elif noise_scheduler.config.prediction_type == "sample":
303 pred_original_sample = model_output
304 elif noise_scheduler.config.prediction_type == "v_prediction":
305 pred_original_sample = alpha * sample - sigma * model_output
306 else:
307 raise ValueError(
308 f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
309 " `v_prediction` for the DDPMScheduler."
310 )
311 return pred_original_sample
288 312
289 313
290def loss_step( 314def loss_step(
@@ -296,6 +320,7 @@ def loss_step(
296 prior_loss_weight: float, 320 prior_loss_weight: float,
297 seed: int, 321 seed: int,
298 offset_noise_strength: float, 322 offset_noise_strength: float,
323 disc: Optional[ConvNeXtDiscriminator],
299 min_snr_gamma: int, 324 min_snr_gamma: int,
300 step: int, 325 step: int,
301 batch: dict[str, Any], 326 batch: dict[str, Any],
@@ -373,23 +398,31 @@ def loss_step(
373 else: 398 else:
374 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 399 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
375 400
376 if guidance_scale == 0 and prior_loss_weight != 0: 401 if disc is None:
377 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 402 if guidance_scale == 0 and prior_loss_weight != 0:
378 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 403 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
379 target, target_prior = torch.chunk(target, 2, dim=0) 404 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
405 target, target_prior = torch.chunk(target, 2, dim=0)
380 406
381 # Compute instance loss 407 # Compute instance loss
382 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 408 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
383 409
384 # Compute prior loss 410 # Compute prior loss
385 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") 411 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
386 412
387 # Add the prior loss to the instance loss. 413 # Add the prior loss to the instance loss.
388 loss = loss + prior_loss_weight * prior_loss 414 loss = loss + prior_loss_weight * prior_loss
389 else: 415 else:
390 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 416 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
391 417
392 loss = loss.mean([1, 2, 3]) 418 loss = loss.mean([1, 2, 3])
419 else:
420 rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
421 rec_latent /= vae.config.scaling_factor
422 rec_latent = rec_latent.to(dtype=vae.dtype)
423 rec = vae.decode(rec_latent).sample
424 loss = 1 - disc.get_score(rec)
425 del rec_latent, rec
393 426
394 if min_snr_gamma != 0: 427 if min_snr_gamma != 0:
395 snr = compute_snr(timesteps, noise_scheduler) 428 snr = compute_snr(timesteps, noise_scheduler)
@@ -645,6 +678,7 @@ def train(
645 guidance_scale: float = 0.0, 678 guidance_scale: float = 0.0,
646 prior_loss_weight: float = 1.0, 679 prior_loss_weight: float = 1.0,
647 offset_noise_strength: float = 0.15, 680 offset_noise_strength: float = 0.15,
681 disc: Optional[ConvNeXtDiscriminator] = None,
648 min_snr_gamma: int = 5, 682 min_snr_gamma: int = 5,
649 **kwargs, 683 **kwargs,
650): 684):
@@ -676,6 +710,7 @@ def train(
676 prior_loss_weight, 710 prior_loss_weight,
677 seed, 711 seed,
678 offset_noise_strength, 712 offset_noise_strength,
713 disc,
679 min_snr_gamma, 714 min_snr_gamma,
680 ) 715 )
681 716