from dataclasses import dataclass
import math
from contextlib import _GeneratorContextManager, nullcontext
from typing import Callable, Any, Tuple, Union, Optional, Protocol
from types import MethodType
from functools import partial
from pathlib import Path
import itertools

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import numpy as np

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    DDPMScheduler,
    UniPCMultistepScheduler,
    SchedulerMixin,
)

from tqdm.auto import tqdm

from data.csv import VlpnDataset
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.embeddings import ManagedCLIPTextEmbeddings
from models.clip.util import get_extended_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import AverageMeter
from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler


def const(result=None):
    """Return a callable that ignores all its arguments and returns ``result``.

    The very same ``result`` object is handed back on every call (it is not
    copied), so callers should treat it as read-only.
    """

    def fn(*args, **kwargs):
        return result

    return fn


@dataclass
class TrainingCallbacks:
    """Hooks invoked by ``train_loop``; every hook defaults to a no-op.

    The comments on the fields describe how ``train_loop`` in this file calls
    each hook.
    """

    # Called once per training batch; the returned dict is merged into the
    # logged metrics.
    on_log: Callable[[], dict[str, Any]] = const({})
    # Called with the cycle number; the returned context manager wraps the
    # training batches of one epoch.
    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
    # Called with the cycle number right before ``optimizer.step()``; its
    # return value is passed on to ``on_after_optimize``.
    on_before_optimize: Callable[[int], Any] = const()
    # Called after the optimizer step with the ``on_before_optimize`` result
    # and the learning rates keyed by group label.
    on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
    # Called after the training batches of each epoch.
    on_after_epoch: Callable[[], None] = const()
    # The returned context manager wraps sampling and the validation loop.
    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
    # Called as ``on_sample(cycle, global_step)``.
    on_sample: Callable[[int, int], None] = const()
    # Called as ``on_checkpoint(global_step, tag)`` where ``tag`` is one of
    # "training", "milestone" or "end".
    on_checkpoint: Callable[[int, str], None] = const()


class TrainingStrategyPrepareCallable(Protocol):
    """Call signature of a strategy's ``prepare`` hook.

    ``train`` unpacks the returned tuple as ``(text_encoder, unet, optimizer,
    train_dataloader, val_dataloader, lr_scheduler)``.
    """

    def __call__(
        self,
        accelerator: Accelerator,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        optimizer: torch.optim.Optimizer,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader],
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        **kwargs,
    ) -> Tuple:
        ...
@dataclass
class TrainingStrategy:
    """Pair of factories that plug a concrete training method into ``train``."""

    # Called by ``train`` with keyword arguments to build the callbacks.
    callbacks: Callable[..., TrainingCallbacks]
    # Called by ``train`` to prepare models, optimizer, loaders and scheduler.
    prepare: TrainingStrategyPrepareCallable


def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
    """Load tokenizer, text encoder, VAE, UNet and both schedulers.

    Each component is loaded from its own subfolder of
    ``pretrained_model_name_or_path``. ``torch_dtype`` is applied to the text
    encoder, the VAE and the UNet.

    Returns:
        ``(tokenizer, text_encoder, vae, unet, noise_scheduler,
        sample_scheduler)``.
    """
    tokenizer = MultiCLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer"
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype
    )
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler"
    )
    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler"
    )

    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler


def save_samples(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: SchedulerMixin,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    output_dir: Path,
    seed: int,
    step: int,
    validation_prompts: Optional[list[str]] = None,
    cycle: int = 1,
    batch_size: int = 1,
    num_batches: int = 1,
    num_steps: int = 20,
    guidance_scale: float = 7.5,
    image_size: Optional[int] = None,
):
    """Render sample images and save one JPEG grid per data pool.

    A grid is written to ``output_dir / pool / f"step_{cycle}_{step}.jpg"`` for
    the "train" pool and, when ``val_dataloader`` is given, for the "stable"
    pool. Only the "stable" pool uses the generator seeded with ``seed``; the
    "train" pool is sampled without an explicit generator.

    ``validation_prompts`` is accepted for interface compatibility but is not
    used by this function.
    """
    print(f"Saving samples for step {step}...")

    grid_cols = min(batch_size, 4)

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=sample_scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    generator = torch.Generator(device=accelerator.device).manual_seed(seed)

    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [
        ("train", train_dataloader, None)
    ]
    if val_dataloader is not None:
        datasets.append(("stable", val_dataloader, generator))

    for pool, data, gen in datasets:
        all_samples = []
        file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Cycle through the loader so short datasets still yield enough prompts.
        batches = list(
            itertools.islice(itertools.cycle(data), batch_size * num_batches)
        )
        prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]]
        nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]]

        with torch.inference_mode():
            for i in range(num_batches):
                start = i * batch_size
                end = (i + 1) * batch_size
                prompt = prompt_ids[start:end]
                nprompt = nprompt_ids[start:end]

                samples = pipeline(
                    prompt=prompt,
                    negative_prompt=nprompt,
                    height=image_size,
                    width=image_size,
                    generator=gen,
                    guidance_scale=guidance_scale,
                    sag_scale=0,
                    num_inference_steps=num_steps,
                ).images

                all_samples.append(torch.from_numpy(samples))

        all_samples = torch.cat(all_samples)

        # NOTE: per-tracker (tensorboard) image logging was disabled in the
        # original code; only the grid file is written.

        # make_grid wants NCHW; numpy_to_pil wants NHWC.
        image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
        image_grid = pipeline.numpy_to_pil(
            image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy()
        )[0]
        image_grid.save(file_path, quality=85)

    del generator, pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def generate_class_images(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    tokenizer: MultiCLIPTokenizer,
    sample_scheduler: SchedulerMixin,
    train_dataset: VlpnDataset,
    sample_batch_size: int,
    sample_image_size: int,
    sample_steps: int,
):
    """Generate and save the class images that do not exist on disk yet.

    Items of ``train_dataset.items`` whose ``class_image_path`` is missing are
    rendered in batches of ``sample_batch_size`` from their ``cprompt`` /
    ``nprompt`` and saved to that path. Returns early when nothing is missing.
    """
    missing_data = [
        item for item in train_dataset.items if not item.class_image_path.exists()
    ]

    if not missing_data:
        return

    batched_data = [
        missing_data[i : i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=sample_scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_name = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps,
            ).images

            for path, image in zip(image_name, images):
                image.save(path)

    del pipeline


def add_placeholder_tokens(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    placeholder_tokens: list[str],
    initializer_tokens: list[str],
    num_vectors: Optional[Union[list[int], int]] = None,
    initializer_noise: float = 0.0,
):
    """Register placeholder tokens and initialise their embeddings.

    Each initializer token is encoded without special tokens. When
    ``num_vectors`` is ``None`` every placeholder gets as many vectors as its
    initializer has token ids. Placeholders and initializers are paired
    positionally with ``zip``, so surplus entries on either side are ignored.

    Returns:
        ``(placeholder_token_ids, initializer_token_ids)``.
    """
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)
        for token in initializer_tokens
    ]

    if num_vectors is None:
        num_vectors = [len(ids) for ids in initializer_token_ids]

    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

    embeddings.resize(len(tokenizer))

    for placeholder_token_id, initializer_token_id in zip(
        placeholder_token_ids, initializer_token_ids
    ):
        embeddings.add_embed(
            placeholder_token_id, initializer_token_id, initializer_noise
        )

    return placeholder_token_ids, initializer_token_ids


def _gather_alpha_sigma(noise_scheduler, timesteps, shape):
    """Look up sqrt(alphas_cumprod) and sqrt(1 - alphas_cumprod) per timestep.

    Both results are float tensors on ``timesteps.device``, padded with
    trailing singleton dimensions and expanded to ``shape``.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    # Cast so that int32 timesteps are accepted as an index tensor as well.
    index = timesteps.long()

    def gather(values):
        values = values.to(device=timesteps.device)[index].float()
        while values.dim() < len(shape):
            values = values[..., None]
        return values.expand(shape)

    alpha = gather(alphas_cumprod**0.5)
    sigma = gather((1.0 - alphas_cumprod) ** 0.5)
    return alpha, sigma


def compute_snr(timesteps, noise_scheduler):
    """
    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Returns a float tensor shaped like ``timesteps`` holding
    ``alphas_cumprod / (1 - alphas_cumprod)`` for each timestep.
    """
    alpha, sigma = _gather_alpha_sigma(noise_scheduler, timesteps, timesteps.shape)
    return (alpha / sigma) ** 2


def get_original(
    noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
):
    """Recover the predicted clean sample from a model output.

    The conversion depends on ``noise_scheduler.config.prediction_type``:
    ``epsilon``, ``sample`` or ``v_prediction``.

    Raises:
        ValueError: for any other prediction type.
    """
    alpha, sigma = _gather_alpha_sigma(noise_scheduler, timesteps, sample.shape)

    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        pred_original_sample = (sample - sigma * model_output) / alpha
    elif prediction_type == "sample":
        pred_original_sample = model_output
    elif prediction_type == "v_prediction":
        pred_original_sample = alpha * sample - sigma * model_output
    else:
        raise ValueError(
            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )
    return pred_original_sample


def loss_step(
    vae: AutoencoderKL,
    noise_scheduler: SchedulerMixin,
    schedule_sampler: ScheduleSampler,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    prior_loss_weight: float,
    seed: int,
    input_pertubation: float,
    min_snr_gamma: int,
    step: int,
    batch: dict[str, Any],
    cache: dict[Any, Any],
    eval: bool = False,
):
    """Compute the diffusion loss for one batch.

    The batch must provide ``pixel_values``, ``input_ids`` and
    ``attention_mask``. In eval mode the latent sampling and the noise use a
    generator seeded with ``seed + step``. ``cache`` is accepted for interface
    compatibility but is not used here.

    When ``prior_loss_weight != 0`` the batch is split into two halves along
    dim 0; the first half is treated as instance samples and the second as
    prior samples, and the result is
    ``mean(instance) + prior_loss_weight * mean(prior)``.

    All per-sample weights (Min-SNR and schedule-sampler weights) are applied
    on the full batch *before* that split. Previously the split happened first,
    leaving a half-batch loss multiplied by full-batch weights, which either
    raised a shape error or silently broadcast.

    Returns:
        ``(loss, acc, bsz)`` where ``loss`` is a scalar tensor, ``acc`` is the
        fraction of elements where prediction and target are exactly equal,
        and ``bsz`` is the batch size.

    Raises:
        ValueError: if the scheduler's prediction type is neither ``epsilon``
            nor ``v_prediction``.
    """
    images = batch["pixel_values"]
    generator = (
        torch.Generator(device=images.device).manual_seed(seed + step)
        if eval
        else None
    )
    bsz = images.shape[0]

    # Convert images to latent space
    latents = vae.encode(images).latent_dist.sample(generator=generator)
    latents = latents * vae.config.scaling_factor

    # Sample noise that we'll add to the latents
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )

    # The perturbed noise is only used for noising the latents; the training
    # target below is built from the unperturbed noise.
    applied_noise = noise
    if input_pertubation != 0:
        applied_noise = applied_noise + input_pertubation * torch.randn(
            latents.shape,
            dtype=latents.dtype,
            layout=latents.layout,
            device=latents.device,
            generator=generator,
        )

    # Sample a random timestep for each image
    timesteps, weights = schedule_sampler.sample(bsz, latents.device)

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = get_extended_embeddings(
        text_encoder, batch["input_ids"], batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(
        noisy_latents, timesteps, encoder_hidden_states, return_dict=False
    )[0]

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(
            f"Unknown prediction type {noise_scheduler.config.prediction_type}"
        )

    acc = (model_pred == target).float().mean()

    # Per-sample loss over the whole batch, so it has one entry per timestep.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean([1, 2, 3])

    if min_snr_gamma != 0:
        # Min-SNR weighting: min(snr, gamma) / snr.
        snr = compute_snr(timesteps, noise_scheduler)
        mse_loss_weights = torch.clamp(snr, max=min_snr_gamma) / snr
        loss = loss * mse_loss_weights

    if isinstance(schedule_sampler, LossAwareSampler):
        schedule_sampler.update_with_all_losses(timesteps, loss.detach())

    loss = loss * weights

    if prior_loss_weight != 0:
        # First half: instance samples, second half: prior samples.
        instance_loss, prior_loss = torch.chunk(loss, 2, dim=0)
        loss = instance_loss.mean() + prior_loss_weight * prior_loss.mean()
    else:
        loss = loss.mean()

    return loss, acc, bsz


class LossCallable(Protocol):
    """Call signature of the loss function consumed by ``train_loop``.

    Returns ``(loss, acc, batch_size)``.
    """

    def __call__(
        self,
        step: int,
        batch: dict[Any, Any],
        cache: dict[str, Any],
        eval: bool = False,
    ) -> Tuple[Any, Any, int]:
        ...
def train_loop(
    accelerator: Accelerator,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    loss_step: LossCallable,
    sample_frequency: int = 10,
    checkpoint_frequency: int = 50,
    milestone_checkpoints: bool = True,
    cycle: int = 0,
    global_step_offset: int = 0,
    num_epochs: int = 100,
    gradient_accumulation_steps: int = 1,
    group_labels: Optional[list[str]] = None,
    avg_loss: Optional[AverageMeter] = None,
    avg_acc: Optional[AverageMeter] = None,
    avg_loss_val: Optional[AverageMeter] = None,
    avg_acc_val: Optional[AverageMeter] = None,
    callbacks: Optional[TrainingCallbacks] = None,
):
    """Run the epoch loop: train, optionally validate, sample and checkpoint.

    Per epoch, on the main process, samples are produced every
    ``sample_frequency`` epochs (skipping epoch 0 unless ``cycle == 0``) and a
    "training" checkpoint is requested every ``checkpoint_frequency`` epochs
    (never at epoch 0). After the training batches the validation loader, if
    any, is evaluated under ``torch.inference_mode()``. When
    ``milestone_checkpoints`` is set, a "milestone" checkpoint is requested
    whenever the tracked accuracy maximum improves (validation accuracy if a
    validation loader is given, training accuracy otherwise).

    After the last epoch, and also on ``KeyboardInterrupt``, an "end"
    checkpoint is requested on the main process; the interrupt is then
    re-raised.

    Args:
        loss_step: called as ``loss_step(step, batch, cache)`` for training
            and ``loss_step(step, batch, cache, True)`` for validation.
        global_step_offset: added to the starting global step and to the
            total step count.
        group_labels: names for the scheduler's learning-rate groups; groups
            without a label are named by their index.
        avg_loss, avg_acc, avg_loss_val, avg_acc_val: running meters to
            continue from; a fresh ``AverageMeter`` is created for each one
            left as ``None``.
        callbacks: hooks to invoke; defaults to no-op ``TrainingCallbacks``.

    Returns:
        ``(avg_loss, avg_acc, avg_loss_val, avg_acc_val)``.
    """
    # Defaults are created per call: a default evaluated at definition time
    # would share its state across every call of this function.
    if group_labels is None:
        group_labels = []
    if avg_loss is None:
        avg_loss = AverageMeter()
    if avg_acc is None:
        avg_acc = AverageMeter()
    if avg_loss_val is None:
        avg_loss_val = AverageMeter()
    if avg_acc_val is None:
        avg_acc_val = AverageMeter()
    if callbacks is None:
        callbacks = TrainingCallbacks()

    num_training_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    num_val_steps_per_epoch = (
        math.ceil(len(val_dataloader) / gradient_accumulation_steps)
        if val_dataloader is not None
        else 0
    )

    num_training_steps = num_training_steps_per_epoch * num_epochs
    num_val_steps = num_val_steps_per_epoch * num_epochs

    global_step = 0
    cache = {}

    best_acc = avg_acc.max
    best_acc_val = avg_acc_val.max

    local_progress_bar = tqdm(
        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True,
    )
    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

    global_progress_bar = tqdm(
        range(num_training_steps + num_val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True,
    )
    global_progress_bar.set_description("Total progress")

    on_log = callbacks.on_log
    on_train = callbacks.on_train
    on_before_optimize = callbacks.on_before_optimize
    on_after_optimize = callbacks.on_after_optimize
    on_after_epoch = callbacks.on_after_epoch
    on_eval = callbacks.on_eval
    on_sample = callbacks.on_sample
    on_checkpoint = callbacks.on_checkpoint

    is_dadaptation = False
    try:
        import dadaptation

        # The optimizer may be wrapped (exposing the real one as `.optimizer`)
        # or passed in bare; handle both instead of raising AttributeError.
        inner_optimizer = getattr(optimizer, "optimizer", optimizer)
        is_dadaptation = isinstance(
            inner_optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)
        )
    except ImportError:
        pass

    num_training_steps += global_step_offset
    global_step += global_step_offset

    try:
        for epoch in range(num_epochs):
            if accelerator.is_main_process:
                if epoch % sample_frequency == 0 and (cycle == 0 or epoch != 0):
                    local_progress_bar.clear()
                    global_progress_bar.clear()

                    with on_eval():
                        on_sample(cycle, global_step)

                if epoch % checkpoint_frequency == 0 and epoch != 0:
                    local_progress_bar.clear()
                    global_progress_bar.clear()

                    on_checkpoint(global_step, "training")

            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            logs = {}

            with on_train(cycle):
                for step, batch in enumerate(train_dataloader):
                    loss, acc, bsz = loss_step(step, batch, cache)
                    loss = loss / gradient_accumulation_steps

                    accelerator.backward(loss)

                    avg_loss.update(loss.item(), bsz)
                    avg_acc.update(acc.item(), bsz)

                    logs = {
                        "train/loss": avg_loss.avg,
                        "train/acc": avg_acc.avg,
                        "train/cur_loss": loss.item(),
                        "train/cur_acc": acc.item(),
                    }

                    lrs: dict[str, float] = {}
                    for i, lr in enumerate(lr_scheduler.get_last_lr()):
                        if torch.is_tensor(lr):
                            lr = lr.item()
                        label = group_labels[i] if i < len(group_labels) else f"{i}"
                        logs[f"lr/{label}"] = lr
                        if is_dadaptation:
                            # For D-Adaptation the effective rate is d * lr.
                            lr = (
                                optimizer.param_groups[i]["d"]
                                * optimizer.param_groups[i]["lr"]
                            )
                            logs[f"d*lr/{label}"] = lr
                        lrs[label] = lr

                    logs.update(on_log())

                    local_progress_bar.set_postfix(**logs)

                    # Step on every accumulation boundary and on the last
                    # batch of the epoch, so leftover gradients are applied.
                    if ((step + 1) % gradient_accumulation_steps == 0) or (
                        (step + 1) == len(train_dataloader)
                    ):
                        before_optimize_result = on_before_optimize(cycle)

                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                        on_after_optimize(before_optimize_result, lrs)

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        accelerator.log(logs, step=global_step)

                        global_step += 1

                        if global_step >= num_training_steps:
                            break

            accelerator.wait_for_everyone()

            on_after_epoch()

            if val_dataloader is not None:
                cur_loss_val = AverageMeter(power=1)
                cur_acc_val = AverageMeter(power=1)

                with torch.inference_mode(), on_eval():
                    for step, batch in enumerate(val_dataloader):
                        loss, acc, bsz = loss_step(step, batch, cache, True)
                        loss = loss / gradient_accumulation_steps

                        cur_loss_val.update(loss.item(), bsz)
                        cur_acc_val.update(acc.item(), bsz)

                        logs = {
                            "val/cur_loss": loss.item(),
                            "val/cur_acc": acc.item(),
                        }
                        local_progress_bar.set_postfix(**logs)

                        if ((step + 1) % gradient_accumulation_steps == 0) or (
                            (step + 1) == len(val_dataloader)
                        ):
                            local_progress_bar.update(1)
                            global_progress_bar.update(1)

                avg_loss_val.update(cur_loss_val.avg)
                avg_acc_val.update(cur_acc_val.avg)

                logs["val/cur_loss"] = cur_loss_val.avg
                logs["val/cur_acc"] = cur_acc_val.avg
                logs["val/loss"] = avg_loss_val.avg
                logs["val/acc"] = avg_acc_val.avg

                accelerator.log(logs, step=global_step)

                if accelerator.is_main_process:
                    if avg_acc_val.max > best_acc_val and milestone_checkpoints:
                        local_progress_bar.clear()
                        global_progress_bar.clear()

                        accelerator.print(
                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}"
                        )
                        on_checkpoint(global_step, "milestone")
                        best_acc_val = avg_acc_val.max
            else:
                if accelerator.is_main_process:
                    if avg_acc.max > best_acc and milestone_checkpoints:
                        local_progress_bar.clear()
                        global_progress_bar.clear()

                        accelerator.print(
                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}"
                        )
                        on_checkpoint(global_step, "milestone")
                        best_acc = avg_acc.max

        # Produce final samples and the final checkpoint.
        if accelerator.is_main_process:
            print("Finished!")
            with on_eval():
                on_sample(cycle, global_step)
            on_checkpoint(global_step, "end")

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted")
            on_checkpoint(global_step, "end")
        # Bare raise keeps the original traceback.
        raise

    return avg_loss, avg_acc, avg_loss_val, avg_acc_val


def _rebind_forward(model) -> None:
    """Bind ``model.forward`` to ``model`` unless it already is bound to it.

    Wrapping an already-bound method in ``MethodType`` again would pass the
    model as an extra positional argument on every later call.
    """
    forward = model.forward
    if isinstance(forward, MethodType) and forward.__self__ is model:
        return
    model.forward = MethodType(forward, model)


def train(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    noise_scheduler: SchedulerMixin,
    dtype: torch.dtype,
    seed: int,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    strategy: TrainingStrategy,
    compile_unet: bool = False,
    no_val: bool = False,
    num_train_epochs: int = 100,
    gradient_accumulation_steps: int = 1,
    group_labels: Optional[list[str]] = None,
    sample_frequency: int = 20,
    checkpoint_frequency: int = 50,
    milestone_checkpoints: bool = True,
    cycle: int = 1,
    global_step_offset: int = 0,
    prior_loss_weight: float = 1.0,
    input_pertubation: float = 0.1,
    schedule_sampler: Optional[ScheduleSampler] = None,
    min_snr_gamma: int = 5,
    avg_loss: Optional[AverageMeter] = None,
    avg_acc: Optional[AverageMeter] = None,
    avg_loss_val: Optional[AverageMeter] = None,
    avg_acc_val: Optional[AverageMeter] = None,
    **kwargs,
):
    """Prepare everything through ``strategy`` and run ``train_loop``.

    ``**kwargs`` are forwarded to both ``strategy.prepare`` and
    ``strategy.callbacks``. The VAE is moved to the accelerator device in
    ``dtype``, frozen, put in eval mode and compiled with the "hidet" backend;
    the UNet is compiled the same way only when ``compile_unet`` is set. When
    ``schedule_sampler`` is ``None`` a ``UniformSampler`` over the scheduler's
    training timesteps is used. With ``no_val`` the validation loader is not
    passed to the loop.

    Meters left as ``None`` are created fresh by ``train_loop``. After
    training, the text encoder and UNet are unwrapped via the accelerator and
    accelerator memory is freed. Returns ``None``.
    """
    (
        text_encoder,
        unet,
        optimizer,
        train_dataloader,
        val_dataloader,
        lr_scheduler,
    ) = strategy.prepare(
        accelerator,
        text_encoder,
        unet,
        optimizer,
        train_dataloader,
        val_dataloader,
        lr_scheduler,
        **kwargs,
    )

    vae.to(accelerator.device, dtype=dtype)
    vae.requires_grad_(False)
    vae.eval()

    vae = torch.compile(vae, backend="hidet")

    if compile_unet:
        unet = torch.compile(unet, backend="hidet")

    callbacks = strategy.callbacks(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        seed=seed,
        **kwargs,
    )

    if schedule_sampler is None:
        schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps)

    # Bind everything except (step, batch, cache, eval), which the loop supplies.
    loss_step_ = partial(
        loss_step,
        vae,
        noise_scheduler,
        schedule_sampler,
        unet,
        text_encoder,
        prior_loss_weight,
        seed,
        input_pertubation,
        min_snr_gamma,
    )

    train_loop(
        accelerator=accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader if not no_val else None,
        loss_step=loss_step_,
        sample_frequency=sample_frequency,
        checkpoint_frequency=checkpoint_frequency,
        milestone_checkpoints=milestone_checkpoints,
        cycle=cycle,
        global_step_offset=global_step_offset,
        num_epochs=num_train_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        group_labels=group_labels,
        avg_loss=avg_loss,
        avg_acc=avg_acc,
        avg_loss_val=avg_loss_val,
        avg_acc_val=avg_acc_val,
        callbacks=callbacks,
    )

    accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
    accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
    # NOTE(review): presumably unwrapping can leave `forward` as a plain,
    # unbound function on the instance; rebind it only in that case.
    _rebind_forward(text_encoder)
    _rebind_forward(unet)
    accelerator.free_memory()