From f39286fa5c5840b67dadf8e85f5f5d7ff1414aab Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 11 Apr 2023 22:36:05 +0200
Subject: Experimental convnext discriminator support

---
 models/convnext/discriminator.py                   | 35 +++++++++
 .../stable_diffusion/vlpn_stable_diffusion.py      |  2 +-
 train_ti.py                                        | 14 ++++
 training/functional.py                             | 83 +++++++++++++++-------
 4 files changed, 109 insertions(+), 25 deletions(-)
 create mode 100644 models/convnext/discriminator.py

diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py
new file mode 100644
index 0000000..7dbbe3a
--- /dev/null
+++ b/models/convnext/discriminator.py
@@ -0,0 +1,35 @@
+import torch
+from timm.models import ConvNeXt
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+from torch.nn import functional as F
+
+
+class ConvNeXtDiscriminator():
+    def __init__(self, model: ConvNeXt, input_size: int) -> None:
+        self.net = model
+
+        self.input_size = input_size
+
+        self.img_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
+        self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)
+
+    def get_score(self, img):
+        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
+        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
+
+        img = ((img+1.)/2.).sub(img_mean).div(img_std)
+
+        img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True)
+        pred = self.net(img)
+        return torch.softmax(pred, dim=-1)[:, 1]
+
+    def get_all(self, img):
+        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
+        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
+
+        img = ((img + 1.) / 2.).sub(img_mean).div(img_std)
+
+        img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True)
+        pred = self.net(img)
+        return pred
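The new class normalizes decoded images from Stable Diffusion's [-1, 1] range to ImageNet statistics, resizes them to the classifier's input resolution, and reads the softmax probability of class index 1 as the image score. A minimal sanity-check sketch, separate from the patch itself (the 3-class convnext_tiny head mirrors train_ti.py below; the random weights are only a stand-in, since the patch never loads a trained checkpoint):

import torch
from timm.models import create_model

from models.convnext.discriminator import ConvNeXtDiscriminator

# Stand-in backbone with the same 3-way head train_ti.py creates; a real
# discriminator would have fine-tuned weights loaded into it first.
convnext = create_model("convnext_tiny", pretrained=False, num_classes=3)
convnext.eval()

disc = ConvNeXtDiscriminator(convnext, input_size=384)

# Two fake decoded images in [-1, 1], the range vae.decode() produces.
imgs = torch.rand(2, 3, 512, 512) * 2 - 1

with torch.no_grad():
    scores = disc.get_score(imgs)  # probability of class 1, shape [2]
    logits = disc.get_all(imgs)    # raw logits, shape [2, 3]

get_score stays differentiable when gradients are enabled, which is what lets 1 - disc.get_score(rec) act as a training loss in training/functional.py further down.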
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cfc3208..13ea2ac 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -477,7 +477,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 2. Define call parameters
         batch_size = len(prompt)
         device = self.execution_device
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
diff --git a/train_ti.py b/train_ti.py
index d7878cd..082e9b7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,10 +13,12 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+from timm.models import create_model
 import transformers

 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -661,6 +663,17 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()

+    convnext = create_model(
+        "convnext_tiny",
+        pretrained=False,
+        num_classes=3,
+        drop_path_rate=0.0,
+    )
+    convnext.to(accelerator.device, dtype=weight_dtype)
+    convnext.requires_grad_(False)
+    convnext.eval()
+    disc = ConvNeXtDiscriminator(convnext, input_size=384)
+
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
@@ -802,6 +815,7 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
+        disc=disc,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
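Note that train_ti.py instantiates the backbone with pretrained=False and never loads a checkpoint, so as committed the discriminator scores with randomly initialized weights; trained weights would have to be restored out of band. A hypothetical loading step (the path and checkpoint format are assumptions, not part of this commit) would slot in before the model is frozen:

import torch
from timm.models import create_model

convnext = create_model(
    "convnext_tiny",
    pretrained=False,
    num_classes=3,
    drop_path_rate=0.0,
)

# Hypothetical: restore a fine-tuned 3-class head; neither this path nor
# the file layout is defined anywhere in the patch.
state_dict = torch.load("checkpoints/convnext_disc.pth", map_location="cpu")
convnext.load_state_dict(state_dict)

convnext.requires_grad_(False)
convnext.eval()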
diff --git a/training/functional.py b/training/functional.py
index 2f7f837..be39776 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
+from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.util import AverageMeter
 from util.slerp import slerp

@@ -160,7 +161,8 @@ def save_samples(

     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
-            tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
+            # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
+            pass

         image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
         image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
@@ -275,16 +277,38 @@ def compute_snr(timesteps, noise_scheduler):
     return snr


-def make_solid_image(color: float, shape, vae, dtype, device, generator):
-    img = torch.tensor(
-        [[[[color]]]],
-        dtype=dtype,
-        device=device
-    ).expand(1, *shape)
-    img = img * 2 - 1
-    img = vae.encode(img).latent_dist.sample(generator=generator)
-    img *= vae.config.scaling_factor
-    return img
+def get_original(
+    noise_scheduler,
+    model_output,
+    sample: torch.FloatTensor,
+    timesteps: torch.IntTensor
+):
+    alphas_cumprod = noise_scheduler.alphas_cumprod
+    sqrt_alphas_cumprod = alphas_cumprod**0.5
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+    alpha = sqrt_alphas_cumprod.expand(sample.shape)
+
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+    sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
+
+    if noise_scheduler.config.prediction_type == "epsilon":
+        pred_original_sample = (sample - sigma * model_output) / alpha
+    elif noise_scheduler.config.prediction_type == "sample":
+        pred_original_sample = model_output
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        pred_original_sample = alpha * sample - sigma * model_output
+    else:
+        raise ValueError(
+            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
+            " `v_prediction` for the DDPMScheduler."
+        )
+    return pred_original_sample


 def loss_step(
@@ -296,6 +320,7 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
+    disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
     step: int,
     batch: dict[str, Any],
@@ -373,23 +398,31 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

-    if guidance_scale == 0 and prior_loss_weight != 0:
-        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-        target, target_prior = torch.chunk(target, 2, dim=0)
+    if disc is None:
+        if guidance_scale == 0 and prior_loss_weight != 0:
+            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+            target, target_prior = torch.chunk(target, 2, dim=0)

-        # Compute instance loss
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+            # Compute instance loss
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

-        # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+            # Compute prior loss
+            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")

-        # Add the prior loss to the instance loss.
-        loss = loss + prior_loss_weight * prior_loss
-    else:
-        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+            # Add the prior loss to the instance loss.
+            loss = loss + prior_loss_weight * prior_loss
+        else:
+            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

-    loss = loss.mean([1, 2, 3])
+        loss = loss.mean([1, 2, 3])
+    else:
+        rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
+        rec_latent /= vae.config.scaling_factor
+        rec_latent = rec_latent.to(dtype=vae.dtype)
+        rec = vae.decode(rec_latent).sample
+        loss = 1 - disc.get_score(rec)
+        del rec_latent, rec

     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
@@ -645,6 +678,7 @@ def train(
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     offset_noise_strength: float = 0.15,
+    disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
     **kwargs,
 ):
@@ -676,6 +710,7 @@ def train(
         prior_loss_weight,
         seed,
         offset_noise_strength,
+        disc,
         min_snr_gamma,
     )
-- 
cgit v1.2.3-70-g09d2
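get_original inverts the forward-diffusion parameterization to recover a predicted clean latent: with x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, epsilon prediction gives x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t), and v-prediction gives x0_hat = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v. A quick numerical check, assuming diffusers' DDPMScheduler as the noise_scheduler (an illustration, not part of the commit):

import torch
from diffusers import DDPMScheduler

from training.functional import get_original

scheduler = DDPMScheduler(prediction_type="epsilon")

latents = torch.randn(4, 4, 64, 64)       # stand-in VAE latents
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (4,))
noisy = scheduler.add_noise(latents, noise, timesteps)

# With a perfect epsilon prediction, get_original recovers the clean latents.
recovered = get_original(scheduler, noise, noisy, timesteps)
assert torch.allclose(recovered, latents, atol=1e-4)

In loss_step, this reconstruction is decoded through the VAE and scored, and 1 - disc.get_score(rec) replaces the per-sample MSE; both branches yield one loss value per batch element, so the min-SNR weighting that follows applies unchanged.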