From f39286fa5c5840b67dadf8e85f5f5d7ff1414aab Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 11 Apr 2023 22:36:05 +0200
Subject: Experimental convnext discriminator support

---
 training/functional.py | 83 +++++++++++++++++++++++++++++++++++---------------
 1 file changed, 59 insertions(+), 24 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 2f7f837..be39776 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
+from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.util import AverageMeter
 from util.slerp import slerp
 
@@ -160,7 +161,8 @@ def save_samples(
 
         for tracker in accelerator.trackers:
             if tracker.name == "tensorboard":
-                tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
+                # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
+                pass
 
         image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
         image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
@@ -275,16 +277,38 @@ def compute_snr(timesteps, noise_scheduler):
     return snr
 
 
-def make_solid_image(color: float, shape, vae, dtype, device, generator):
-    img = torch.tensor(
-        [[[[color]]]],
-        dtype=dtype,
-        device=device
-    ).expand(1, *shape)
-    img = img * 2 - 1
-    img = vae.encode(img).latent_dist.sample(generator=generator)
-    img *= vae.config.scaling_factor
-    return img
+def get_original(
+    noise_scheduler,
+    model_output,
+    sample: torch.FloatTensor,
+    timesteps: torch.IntTensor
+):
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = alphas_cumprod**0.5
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+    alpha = sqrt_alphas_cumprod.expand(sample.shape)
+
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+    sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
+
+    if noise_scheduler.config.prediction_type == "epsilon":
+        pred_original_sample = (sample - sigma * model_output) / alpha
+    elif noise_scheduler.config.prediction_type == "sample":
+        pred_original_sample = model_output
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        pred_original_sample = alpha * sample - sigma * model_output
+    else:
+        raise ValueError(
+            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
+            " `v_prediction` for the DDPMScheduler."
+        )
+    return pred_original_sample
 
 
 def loss_step(
@@ -296,6 +320,7 @@
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
+    disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
     step: int,
     batch: dict[str, Any],
@@ -373,23 +398,31 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if guidance_scale == 0 and prior_loss_weight != 0:
-        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-        target, target_prior = torch.chunk(target, 2, dim=0)
+    if disc is None:
+        if guidance_scale == 0 and prior_loss_weight != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)
 
-        # Compute instance loss
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
-        # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
 
-        # Add the prior loss to the instance loss.
-        loss = loss + prior_loss_weight * prior_loss
-    else:
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+            # Add the prior loss to the instance loss.
+            loss = loss + prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
-    loss = loss.mean([1, 2, 3])
+        loss = loss.mean([1, 2, 3])
+    else:
+        rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
+        rec_latent /= vae.config.scaling_factor
+        rec_latent = rec_latent.to(dtype=vae.dtype)
+        rec = vae.decode(rec_latent).sample
+        loss = 1 - disc.get_score(rec)
+        del rec_latent, rec
 
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
@@ -645,6 +678,7 @@ def train(
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     offset_noise_strength: float = 0.15,
+    disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
     **kwargs,
 ):
@@ -676,6 +710,7 @@ def train(
         prior_loss_weight,
         seed,
         offset_noise_strength,
+        disc,
         min_snr_gamma,
     )
 
-- 
cgit v1.2.3-70-g09d2
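
Editorial note (not part of the commit): when a discriminator is passed, loss_step() no longer compares model_pred against the noise/velocity target with MSE. It calls get_original() to recover the predicted original latents, divides them by vae.config.scaling_factor, decodes them through the VAE, and uses 1 - disc.get_score(rec) as the loss. The sketch below reproduces, with plain torch only, the x0 reconstruction that get_original() performs for the three scheduler prediction types; the helper name reconstruct_x0 and the toy alpha_bar value are illustrative and not taken from the repository.

# Minimal sketch, assuming plain torch; mirrors the math in get_original().
import torch

def reconstruct_x0(prediction_type: str, model_output: torch.Tensor,
                   x_t: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    alpha = alpha_bar_t.sqrt()          # sqrt(alpha_bar_t) at timestep t
    sigma = (1.0 - alpha_bar_t).sqrt()  # sqrt(1 - alpha_bar_t)
    if prediction_type == "epsilon":
        return (x_t - sigma * model_output) / alpha
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        return alpha * x_t - sigma * model_output
    raise ValueError(f"unknown prediction type: {prediction_type}")

# Sanity check: noise a known x0 with the DDPM forward process, then recover it
# exactly from the true epsilon and from the true v target.
torch.manual_seed(0)
x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x0)
alpha_bar_t = torch.tensor(0.7)  # toy value; real schedules index alphas_cumprod[t]
x_t = alpha_bar_t.sqrt() * x0 + (1.0 - alpha_bar_t).sqrt() * eps
v = alpha_bar_t.sqrt() * eps - (1.0 - alpha_bar_t).sqrt() * x0
assert torch.allclose(reconstruct_x0("epsilon", eps, x_t, alpha_bar_t), x0, atol=1e-5)
assert torch.allclose(reconstruct_x0("v_prediction", v, x_t, alpha_bar_t), x0, atol=1e-5)

The epsilon and v_prediction branches are exact inverses of the DDPM forward noising step x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, which is why the decoded reconstruction rec is a meaningful input for the ConvNeXt discriminator even at high noise levels.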