from dataclasses import dataclass
import math
from contextlib import _GeneratorContextManager, nullcontext
from typing import Callable, Any, Tuple, Union, Optional, Protocol
from functools import partial
from pathlib import Path
import itertools

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import numpy as np

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin

from tqdm.auto import tqdm

from data.csv import VlpnDataset
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
from models.clip.util import get_extended_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from models.convnext.discriminator import ConvNeXtDiscriminator
from training.util import AverageMeter
from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
from util.slerp import slerp
from util.noise import perlin_noise


def const(result=None):
    def fn(*args, **kwargs):
        return result

    return fn


@dataclass
class TrainingCallbacks:
    on_log: Callable[[], dict[str, Any]] = const({})
    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
    on_before_optimize: Callable[[int], Any] = const()
    on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
    on_after_epoch: Callable[[], None] = const()
    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
    on_sample: Callable[[int, int], None] = const()
    on_checkpoint: Callable[[int, str], None] = const()


class TrainingStrategyPrepareCallable(Protocol):
    def __call__(
        self,
        accelerator: Accelerator,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        optimizer: torch.optim.Optimizer,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader],
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        **kwargs,
    ) -> Tuple:
        ...
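
# Illustrative sketch only (never called by the trainer): one way a strategy
# might override a single callback while keeping the const() no-op defaults
# for the rest. The "example/constant" metric name is made up.
def _example_callbacks() -> TrainingCallbacks:
    def on_log() -> dict[str, Any]:
        # Extra values merged into the progress-bar/tracker logs each step.
        return {"example/constant": 1.0}

    return TrainingCallbacks(on_log=on_log)
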
@dataclass
class TrainingStrategy:
    callbacks: Callable[..., TrainingCallbacks]
    prepare: TrainingStrategyPrepareCallable


def get_models(
    pretrained_model_name_or_path: str,
    emb_alpha: int = 8,
    emb_dropout: float = 0.0,
):
    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder='scheduler')

    embeddings = patch_managed_embeddings(text_encoder, emb_alpha, emb_dropout)

    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings


def save_samples(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: SchedulerMixin,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    output_dir: Path,
    seed: int,
    step: int,
    cycle: int = 1,
    batch_size: int = 1,
    num_batches: int = 1,
    num_steps: int = 20,
    guidance_scale: float = 7.5,
    image_size: Optional[int] = None,
):
    print(f"Saving samples for step {step}...")

    grid_cols = min(batch_size, 4)

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=sample_scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    generator = torch.Generator(device=accelerator.device).manual_seed(seed)

    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]

    if val_dataloader is not None:
        datasets.append(("stable", val_dataloader, generator))
        datasets.append(("val", val_dataloader, None))

    for pool, data, gen in datasets:
        all_samples = []
        file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
        file_path.parent.mkdir(parents=True, exist_ok=True)

        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
        prompt_ids = [
            prompt
            for batch in batches
            for prompt in batch["prompt_ids"]
        ]
        nprompt_ids = [
            prompt
            for batch in batches
            for prompt in batch["nprompt_ids"]
        ]

        with torch.inference_mode():
            for i in range(num_batches):
                start = i * batch_size
                end = (i + 1) * batch_size
                prompt = prompt_ids[start:end]
                nprompt = nprompt_ids[start:end]

                samples = pipeline(
                    prompt=prompt,
                    negative_prompt=nprompt,
                    height=image_size,
                    width=image_size,
                    generator=gen,
                    guidance_scale=guidance_scale,
                    sag_scale=0,
                    num_inference_steps=num_steps,
                    output_type=None,
                ).images

                all_samples.append(torch.from_numpy(samples))

        all_samples = torch.cat(all_samples)

        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC")
                pass

        image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
        image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
        image_grid.save(file_path, quality=85)

    del generator, pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
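
# Minimal usage sketch for get_models(); the checkpoint path is hypothetical
# and assumes the standard diffusers subfolder layout (tokenizer/, vae/,
# unet/, scheduler/, ...). Defined for illustration only.
def _example_load_models():
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "path/to/checkpoint", emb_alpha=8, emb_dropout=0.0
    )
    # These objects feed save_samples(), generate_class_images(), and train().
    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
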
def generate_class_images(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    tokenizer: MultiCLIPTokenizer,
    sample_scheduler: SchedulerMixin,
    train_dataset: VlpnDataset,
    sample_batch_size: int,
    sample_image_size: int,
    sample_steps: int,
):
    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]

    if len(missing_data) == 0:
        return

    batched_data = [
        missing_data[i:i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=sample_scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_name = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps,
            ).images

            for i, image in enumerate(images):
                image.save(image_name[i])

    del pipeline


def add_placeholder_tokens(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    placeholder_tokens: list[str],
    initializer_tokens: list[str],
    num_vectors: Optional[Union[list[int], int]] = None,
    initializer_noise: float = 0.0,
):
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)
        for token in initializer_tokens
    ]

    if num_vectors is None:
        num_vectors = [len(ids) for ids in initializer_token_ids]

    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

    embeddings.resize(len(tokenizer))

    for placeholder_token_id, initializer_token_id in zip(placeholder_token_ids, initializer_token_ids):
        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)

    return placeholder_token_ids, initializer_token_ids
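
# Usage sketch for add_placeholder_tokens() with hypothetical token names:
# two new placeholder tokens are mapped onto the embeddings of existing
# initializer words, letting num_vectors default to the length of each
# initializer's encoding. Illustration only; never called by the trainer.
def _example_add_tokens(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings):
    return add_placeholder_tokens(
        tokenizer,
        embeddings,
        placeholder_tokens=["<concept>", "<style>"],
        initializer_tokens=["dog", "painting"],
        initializer_noise=0.0,
    )
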
def compute_snr(timesteps, noise_scheduler):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr


def get_original(
    noise_scheduler,
    model_output,
    sample: torch.FloatTensor,
    timesteps: torch.IntTensor,
):
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(sample.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)

    if noise_scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (sample - sigma * model_output) / alpha
    elif noise_scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif noise_scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = alpha * sample - sigma * model_output
    else:
        raise ValueError(
            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    return pred_original_sample
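
# Numeric sanity-check sketch for compute_snr(): with alpha = sqrt(alpha_bar_t)
# and sigma = sqrt(1 - alpha_bar_t), (alpha / sigma)**2 equals
# alpha_bar_t / (1 - alpha_bar_t). Uses the default diffusers DDPMScheduler
# config; defined for illustration and never called by the trainer.
def _example_snr_check():
    scheduler = DDPMScheduler()
    timesteps = torch.tensor([0, 500, 999])
    snr = compute_snr(timesteps, scheduler)
    alphas_cumprod = scheduler.alphas_cumprod[timesteps]
    expected = alphas_cumprod / (1.0 - alphas_cumprod)
    assert torch.allclose(snr, expected.float(), rtol=1e-4)
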
def loss_step(
    vae: AutoencoderKL,
    noise_scheduler: SchedulerMixin,
    schedule_sampler: ScheduleSampler,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    guidance_scale: float,
    prior_loss_weight: float,
    seed: int,
    offset_noise_strength: float,
    input_perturbation: float,
    disc: Optional[ConvNeXtDiscriminator],
    min_snr_gamma: int,
    step: int,
    batch: dict[str, Any],
    cache: dict[Any, Any],
    eval: bool = False,
):
    images = batch["pixel_values"]
    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
    bsz = images.shape[0]

    # Convert images to latent space
    latents = vae.encode(images).latent_dist.sample(generator=generator)
    latents = latents * vae.config.scaling_factor

    # Sample noise that we'll add to the latents
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    applied_noise = noise

    if offset_noise_strength != 0:
        applied_noise = applied_noise + offset_noise_strength * perlin_noise(
            latents.shape,
            res=1,
            octaves=4,
            dtype=latents.dtype,
            device=latents.device,
            generator=generator,
        )

    if input_perturbation != 0:
        applied_noise = applied_noise + input_perturbation * torch.randn(
            latents.shape,
            dtype=latents.dtype,
            layout=latents.layout,
            device=latents.device,
            generator=generator,
        )

    # Sample a random timestep for each image
    timesteps, weights = schedule_sampler.sample(bsz, latents.device)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = get_extended_embeddings(
        text_encoder,
        batch["input_ids"],
        batch["attention_mask"],
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

    if guidance_scale != 0:
        uncond_encoder_hidden_states = get_extended_embeddings(
            text_encoder,
            batch["negative_input_ids"],
            batch["negative_attention_mask"],
        )
        uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)

        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    acc = (model_pred == target).float().mean()

    if guidance_scale == 0 and prior_loss_weight != 0:
        # Chunk the predictions and targets into instance and prior parts
        # and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

    loss = loss.mean([1, 2, 3])

    if disc is not None:
        rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
        rec_latent = rec_latent / vae.config.scaling_factor
        rec_latent = rec_latent.to(dtype=vae.dtype)
        rec = vae.decode(rec_latent, return_dict=False)[0]
        loss = 1 - disc.get_score(rec)

    if min_snr_gamma != 0:
        snr = compute_snr(timesteps, noise_scheduler)
        mse_loss_weights = (
            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
        )
        loss = loss * mse_loss_weights

    if isinstance(schedule_sampler, LossAwareSampler):
        schedule_sampler.update_with_all_losses(timesteps, loss.detach())

    loss = loss * weights
    loss = loss.mean()

    return loss, acc, bsz


class LossCallable(Protocol):
    def __call__(
        self,
        step: int,
        batch: dict[str, Any],
        cache: dict[Any, Any],
        eval: bool = False,
    ) -> Tuple[Any, Any, int]:
        ...
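
# Sketch of how loss_step() becomes a LossCallable via functools.partial,
# mirroring what train() does below. All bound values here are placeholders
# chosen for illustration; this helper is never called by the trainer.
def _example_bind_loss_step(vae, noise_scheduler, schedule_sampler, unet, text_encoder) -> LossCallable:
    return partial(
        loss_step,
        vae, noise_scheduler, schedule_sampler, unet, text_encoder,
        0.0,   # guidance_scale
        1.0,   # prior_loss_weight
        42,    # seed
        0.0,   # offset_noise_strength
        0.0,   # input_perturbation
        None,  # disc
        5,     # min_snr_gamma
    )
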
def train_loop(
    accelerator: Accelerator,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    loss_step: LossCallable,
    sample_frequency: int = 10,
    checkpoint_frequency: int = 50,
    milestone_checkpoints: bool = True,
    cycle: int = 0,
    global_step_offset: int = 0,
    num_epochs: int = 100,
    gradient_accumulation_steps: int = 1,
    group_labels: list[str] = [],
    avg_loss: Optional[AverageMeter] = None,
    avg_acc: Optional[AverageMeter] = None,
    avg_loss_val: Optional[AverageMeter] = None,
    avg_acc_val: Optional[AverageMeter] = None,
    callbacks: TrainingCallbacks = TrainingCallbacks(),
):
    # Create fresh meters per call instead of sharing mutable defaults
    # between invocations.
    avg_loss = avg_loss if avg_loss is not None else AverageMeter()
    avg_acc = avg_acc if avg_acc is not None else AverageMeter()
    avg_loss_val = avg_loss_val if avg_loss_val is not None else AverageMeter()
    avg_acc_val = avg_acc_val if avg_acc_val is not None else AverageMeter()

    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_val_steps_per_epoch = math.ceil(
        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
    num_training_steps = num_training_steps_per_epoch * num_epochs
    num_val_steps = num_val_steps_per_epoch * num_epochs

    global_step = 0
    cache = {}

    best_acc = avg_acc.max
    best_acc_val = avg_acc_val.max

    local_progress_bar = tqdm(
        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True,
    )
    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

    global_progress_bar = tqdm(
        range(num_training_steps + num_val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True,
    )
    global_progress_bar.set_description("Total progress")

    on_log = callbacks.on_log
    on_train = callbacks.on_train
    on_before_optimize = callbacks.on_before_optimize
    on_after_optimize = callbacks.on_after_optimize
    on_after_epoch = callbacks.on_after_epoch
    on_eval = callbacks.on_eval
    on_sample = callbacks.on_sample
    on_checkpoint = callbacks.on_checkpoint

    isDadaptation = False

    try:
        import dadaptation

        isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
    except ImportError:
        pass

    num_training_steps += global_step_offset
    global_step += global_step_offset

    try:
        for epoch in range(num_epochs):
            if accelerator.is_main_process:
                if epoch % sample_frequency == 0 and (cycle == 0 or epoch != 0):
                    local_progress_bar.clear()
                    global_progress_bar.clear()

                    with on_eval():
                        on_sample(cycle, global_step)

                if epoch % checkpoint_frequency == 0 and epoch != 0:
                    local_progress_bar.clear()
                    global_progress_bar.clear()

                    on_checkpoint(global_step, "training")

            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            logs = {}

            with on_train(cycle):
                for step, batch in enumerate(train_dataloader):
                    loss, acc, bsz = loss_step(step, batch, cache)
                    loss = loss / gradient_accumulation_steps

                    accelerator.backward(loss)

                    avg_loss.update(loss.item(), bsz)
                    avg_acc.update(acc.item(), bsz)

                    logs = {
                        "train/loss": avg_loss.avg,
                        "train/acc": avg_acc.avg,
                        "train/cur_loss": loss.item(),
                        "train/cur_acc": acc.item(),
                    }

                    lrs: dict[str, float] = {}
                    for i, lr in enumerate(lr_scheduler.get_last_lr()):
                        if torch.is_tensor(lr):
                            lr = lr.item()
                        label = group_labels[i] if i < len(group_labels) else f"{i}"
                        logs[f"lr/{label}"] = lr
                        if isDadaptation:
                            lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
                            logs[f"d*lr/{label}"] = lr
                        lrs[label] = lr

                    logs.update(on_log())

                    local_progress_bar.set_postfix(**logs)

                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
                        before_optimize_result = on_before_optimize(epoch)

                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad(set_to_none=True)
                        on_after_optimize(before_optimize_result, lrs)

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        accelerator.log(logs, step=global_step)

                        global_step += 1

                    if global_step >= num_training_steps:
                        break

            accelerator.wait_for_everyone()

            on_after_epoch()

            if val_dataloader is not None:
                cur_loss_val = AverageMeter(power=1)
                cur_acc_val = AverageMeter(power=1)

                with torch.inference_mode(), on_eval():
                    for step, batch in enumerate(val_dataloader):
                        loss, acc, bsz = loss_step(step, batch, cache, True)

                        loss = loss / gradient_accumulation_steps

                        cur_loss_val.update(loss.item(), bsz)
                        cur_acc_val.update(acc.item(), bsz)

                        logs = {
                            "val/cur_loss": loss.item(),
                            "val/cur_acc": acc.item(),
                        }
                        local_progress_bar.set_postfix(**logs)

                        if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
                            local_progress_bar.update(1)
                            global_progress_bar.update(1)

                avg_loss_val.update(cur_loss_val.avg)
                avg_acc_val.update(cur_acc_val.avg)

                logs["val/cur_loss"] = cur_loss_val.avg
                logs["val/cur_acc"] = cur_acc_val.avg
                logs["val/loss"] = avg_loss_val.avg
                logs["val/acc"] = avg_acc_val.avg

                accelerator.log(logs, step=global_step)

                if accelerator.is_main_process:
                    if avg_acc_val.max > best_acc_val and milestone_checkpoints:
                        local_progress_bar.clear()
                        global_progress_bar.clear()

                        accelerator.print(
                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.max:.2e}")
                        on_checkpoint(global_step, "milestone")
                        best_acc_val = avg_acc_val.max
            else:
                if accelerator.is_main_process:
                    if avg_acc.max > best_acc and milestone_checkpoints:
                        local_progress_bar.clear()
                        global_progress_bar.clear()

                        accelerator.print(
                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.max:.2e}")
                        on_checkpoint(global_step, "milestone")
                        best_acc = avg_acc.max

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished!")
            with on_eval():
                on_sample(cycle, global_step)
            on_checkpoint(global_step, "end")

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted")
            on_checkpoint(global_step, "end")
        raise

    return avg_loss, avg_acc, avg_loss_val, avg_acc_val
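
# Worked example of the step accounting in train_loop() above: with 10 batches
# per epoch and gradient_accumulation_steps=4, the optimizer steps at batches
# 4, 8, and 10, i.e. ceil(10 / 4) = 3 steps per epoch, so num_training_steps
# is 3 * num_epochs. Illustration only; the numbers are arbitrary.
def _example_step_accounting(num_batches: int = 10, accum: int = 4, num_epochs: int = 100) -> int:
    return math.ceil(num_batches / accum) * num_epochs
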
def train(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    noise_scheduler: SchedulerMixin,
    dtype: torch.dtype,
    seed: int,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    strategy: TrainingStrategy,
    compile_unet: bool = False,
    no_val: bool = False,
    num_train_epochs: int = 100,
    gradient_accumulation_steps: int = 1,
    group_labels: list[str] = [],
    sample_frequency: int = 20,
    checkpoint_frequency: int = 50,
    milestone_checkpoints: bool = True,
    cycle: int = 1,
    global_step_offset: int = 0,
    guidance_scale: float = 0.0,
    prior_loss_weight: float = 1.0,
    offset_noise_strength: float = 0.01,
    input_perturbation: float = 0.1,
    disc: Optional[ConvNeXtDiscriminator] = None,
    schedule_sampler: Optional[ScheduleSampler] = None,
    min_snr_gamma: int = 5,
    avg_loss: Optional[AverageMeter] = None,
    avg_acc: Optional[AverageMeter] = None,
    avg_loss_val: Optional[AverageMeter] = None,
    avg_acc_val: Optional[AverageMeter] = None,
    **kwargs,
):
    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)

    vae.to(accelerator.device, dtype=dtype)
    vae.requires_grad_(False)
    vae.eval()

    vae = torch.compile(vae, backend='hidet')

    if compile_unet:
        unet = torch.compile(unet, backend='hidet')
        # unet = torch.compile(unet, mode="reduce-overhead")

    callbacks = strategy.callbacks(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        seed=seed,
        **kwargs,
    )

    if schedule_sampler is None:
        schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps)

    loss_step_ = partial(
        loss_step,
        vae,
        noise_scheduler,
        schedule_sampler,
        unet,
        text_encoder,
        guidance_scale,
        prior_loss_weight,
        seed,
        offset_noise_strength,
        input_perturbation,
        disc,
        min_snr_gamma,
    )

    train_loop(
        accelerator=accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader if not no_val else None,
        loss_step=loss_step_,
        sample_frequency=sample_frequency,
        checkpoint_frequency=checkpoint_frequency,
        milestone_checkpoints=milestone_checkpoints,
        cycle=cycle,
        global_step_offset=global_step_offset,
        num_epochs=num_train_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        group_labels=group_labels,
        avg_loss=avg_loss,
        avg_acc=avg_acc,
        avg_loss_val=avg_loss_val,
        avg_acc_val=avg_acc_val,
        callbacks=callbacks,
    )

    accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
    accelerator.unwrap_model(unet, keep_fp32_wrapper=False)

    accelerator.free_memory()
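
# End-to-end wiring sketch (nothing here runs on import). The checkpoint path
# is hypothetical, and a real caller would also build the dataloaders,
# optimizer, LR scheduler, and strategy before invoking train().
def _example_train_entry(
    strategy: TrainingStrategy,
    train_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
):
    accelerator = Accelerator(gradient_accumulation_steps=1)
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "path/to/checkpoint"
    )
    # tokenizer, sample_scheduler, and embeddings would typically feed the
    # strategy callbacks and save_samples(); train() itself does not take them.
    train(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        dtype=torch.float32,
        seed=42,
        train_dataloader=train_dataloader,
        val_dataloader=None,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        strategy=strategy,
        num_train_epochs=10,
    )
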