from dataclasses import dataclass
import math
from contextlib import _GeneratorContextManager, nullcontext
from typing import Callable, Any, Tuple, Union, Optional, Protocol
from functools import partial
from pathlib import Path
import itertools

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin

from tqdm.auto import tqdm
from PIL import Image

from data.csv import VlpnDataset
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
from models.clip.util import get_extended_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import AverageMeter
from util.slerp import slerp


def const(result=None):
    def fn(*args, **kwargs):
        return result
    return fn


@dataclass
class TrainingCallbacks():
    on_log: Callable[[], dict[str, Any]] = const({})
    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
    on_before_optimize: Callable[[int], Any] = const()
    on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
    on_after_epoch: Callable[[], None] = const()
    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
    on_sample: Callable[[int], None] = const()
    on_checkpoint: Callable[[int, str], None] = const()


class TrainingStrategyPrepareCallable(Protocol):
    def __call__(
        self,
        accelerator: Accelerator,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        optimizer: torch.optim.Optimizer,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader],
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        **kwargs
    ) -> Tuple: ...
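
# A minimal sketch of the strategy contract (hypothetical, not part of this
# module): train() below unpacks strategy.prepare() into six prepared objects
# plus a dict of extra state. A pass-through implementation, assuming
# val_dataloader is not None, could look like:
#
#   def prepare(accelerator, text_encoder, unet, optimizer,
#               train_dataloader, val_dataloader, lr_scheduler, **kwargs):
#       text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = \
#           accelerator.prepare(text_encoder, unet, optimizer,
#                               train_dataloader, val_dataloader, lr_scheduler)
#       return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
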
@dataclass
class TrainingStrategy():
    callbacks: Callable[..., TrainingCallbacks]
    prepare: TrainingStrategyPrepareCallable


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=((i % cols) * w, (i // cols) * h))
    return grid


def get_models(pretrained_model_name_or_path: str):
    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder='scheduler')

    embeddings = patch_managed_embeddings(text_encoder)

    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings


def save_samples(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: SchedulerMixin,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    output_dir: Path,
    seed: int,
    step: int,
    batch_size: int = 1,
    num_batches: int = 1,
    num_steps: int = 20,
    guidance_scale: float = 7.5,
    image_size: Optional[int] = None,
):
    print(f"Saving samples for step {step}...")

    grid_cols = min(batch_size, 4)
    grid_rows = (num_batches * batch_size) // grid_cols

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=sample_scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    generator = torch.Generator(device=accelerator.device).manual_seed(seed)

    # "stable" reuses the seeded generator so its samples stay comparable
    # across steps; "train" and "val" sample freely.
    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]

    if val_dataloader is not None:
        datasets.append(("stable", val_dataloader, generator))
        datasets.append(("val", val_dataloader, None))

    for pool, data, gen in datasets:
        all_samples = []
        file_path = output_dir / pool / f"step_{step}.jpg"
        file_path.parent.mkdir(parents=True, exist_ok=True)

        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
        prompt_ids = [
            prompt
            for batch in batches
            for prompt in batch["prompt_ids"]
        ]
        nprompt_ids = [
            prompt
            for batch in batches
            for prompt in batch["nprompt_ids"]
        ]

        for i in range(num_batches):
            start = i * batch_size
            end = (i + 1) * batch_size
            prompt = prompt_ids[start:end]
            nprompt = nprompt_ids[start:end]

            samples = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=image_size,
                width=image_size,
                generator=gen,
                guidance_scale=guidance_scale,
                sag_scale=0,
                num_inference_steps=num_steps,
                output_type='pil'
            ).images

            all_samples += samples

        image_grid = make_grid(all_samples, grid_rows, grid_cols)
        image_grid.save(file_path, quality=85)

    del generator
    del pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
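
# Usage sketch (hypothetical checkpoint name and values): get_models() loads
# all components from one diffusers-format checkpoint, after which
# save_samples() renders one grid per data pool ("train", plus "stable" and
# "val" when a validation loader exists).
#
#   tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
#       get_models("stabilityai/stable-diffusion-2-1")
#   save_samples(accelerator, unet, text_encoder, tokenizer, vae, sample_scheduler,
#                train_dataloader, val_dataloader, output_dir=Path("output/samples"),
#                seed=42, step=0, batch_size=4, num_batches=2)
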
def generate_class_images(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    tokenizer: MultiCLIPTokenizer,
    sample_scheduler: SchedulerMixin,
    train_dataset: VlpnDataset,
    sample_batch_size: int,
    sample_image_size: int,
    sample_steps: int
):
    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]

    if len(missing_data) == 0:
        return

    batched_data = [
        missing_data[i:i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=sample_scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_name = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            ).images

            for i, image in enumerate(images):
                image.save(image_name[i])

    del pipeline


def add_placeholder_tokens(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    placeholder_tokens: list[str],
    initializer_tokens: list[str],
    num_vectors: Optional[Union[list[int], int]] = None,
    initializer_noise: float = 0.0,
):
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)
        for token in initializer_tokens
    ]

    if num_vectors is None:
        num_vectors = [len(ids) for ids in initializer_token_ids]

    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
    embeddings.resize(len(tokenizer))

    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)

    return placeholder_token_ids, initializer_token_ids
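
# Usage sketch (hypothetical tokens): each placeholder is initialized from the
# embedding(s) of its initializer token; with num_vectors=None a placeholder
# receives one vector per token id of its initializer.
#
#   placeholder_ids, initializer_ids = add_placeholder_tokens(
#       tokenizer, embeddings,
#       placeholder_tokens=["<concept>"],
#       initializer_tokens=["dog"],
#       initializer_noise=0.1,
#   )
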
def compute_snr(timesteps, noise_scheduler):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr


def make_solid_image(color: float, shape, vae, dtype, device, generator):
    # Encode a single solid-color image (color in [0, 1], rescaled to [-1, 1])
    # into VAE latent space.
    img = torch.tensor(
        [[[[color]]]],
        dtype=dtype,
        device=device
    ).expand(1, *shape)
    img = img * 2 - 1
    img = vae.encode(img).latent_dist.sample(generator=generator)
    img *= vae.config.scaling_factor
    return img
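
# Worked example of the Min-SNR weighting that loss_step() applies below
# (standalone arithmetic, not from a real run): the per-timestep weight is
# min(SNR, gamma) / SNR. With gamma = 5, a timestep with SNR 20 is weighted
# min(20, 5) / 20 = 0.25, while one with SNR 2 keeps min(2, 5) / 2 = 1.0,
# so easy, low-noise timesteps contribute less to the loss.
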
def loss_step(
    vae: AutoencoderKL,
    noise_scheduler: SchedulerMixin,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    guidance_scale: float,
    prior_loss_weight: float,
    seed: int,
    offset_noise_strength: float,
    min_snr_gamma: int,
    step: int,
    batch: dict[str, Any],
    cache: dict[Any, Any],
    eval: bool = False,
):
    images = batch["pixel_values"]
    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
    bsz = images.shape[0]

    # Convert images to latent space
    latents = vae.encode(images).latent_dist.sample(generator=generator)
    latents *= vae.config.scaling_factor

    # Sample noise that we'll add to the latents
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator
    )

    if offset_noise_strength != 0:
        offset_noise = torch.randn(
            (latents.shape[0], latents.shape[1], 1, 1),
            dtype=latents.dtype,
            device=latents.device,
            generator=generator
        ).expand(noise.shape)
        noise += offset_noise_strength * offset_noise

    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=generator,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = get_extended_embeddings(
        text_encoder,
        batch["input_ids"],
        batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    if guidance_scale != 0:
        uncond_encoder_hidden_states = get_extended_embeddings(
            text_encoder,
            batch["negative_input_ids"],
            batch["negative_attention_mask"]
        )
        uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)

        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if guidance_scale == 0 and prior_loss_weight != 0:
        # Chunk model_pred and target into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

    loss = loss.mean([1, 2, 3])

    if min_snr_gamma != 0:
        # Min-SNR loss weighting: min(SNR, gamma) / SNR per timestep.
        snr = compute_snr(timesteps, noise_scheduler)
        mse_loss_weights = (
            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
        )
        loss *= mse_loss_weights

    loss = loss.mean()

    acc = (model_pred == target).float().mean()

    return loss, acc, bsz


class LossCallable(Protocol):
    def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any],
                 eval: bool = False) -> Tuple[Any, Any, int]: ...
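
# train() below builds the LossCallable by partially applying loss_step(); a
# hand-rolled equivalent (hypothetical argument values) would be:
#
#   loss_fn = partial(loss_step, vae, noise_scheduler, unet, text_encoder,
#                     0.0,   # guidance_scale
#                     1.0,   # prior_loss_weight
#                     42,    # seed
#                     0.15,  # offset_noise_strength
#                     5)     # min_snr_gamma
#   loss, acc, bsz = loss_fn(step, batch, cache={})
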
def train_loop(
    accelerator: Accelerator,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    loss_step: LossCallable,
    sample_frequency: int = 10,
    checkpoint_frequency: int = 50,
    milestone_checkpoints: bool = True,
    global_step_offset: int = 0,
    num_epochs: int = 100,
    gradient_accumulation_steps: int = 1,
    group_labels: list[str] = [],
    callbacks: TrainingCallbacks = TrainingCallbacks(),
):
    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0

    num_training_steps = num_training_steps_per_epoch * num_epochs
    num_val_steps = num_val_steps_per_epoch * num_epochs

    global_step = 0
    cache = {}

    avg_loss = AverageMeter()
    avg_acc = AverageMeter()

    avg_loss_val = AverageMeter()
    avg_acc_val = AverageMeter()

    best_acc = 0.0
    best_acc_val = 0.0

    local_progress_bar = tqdm(
        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

    global_progress_bar = tqdm(
        range(num_training_steps + num_val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    global_progress_bar.set_description("Total progress")

    on_log = callbacks.on_log
    on_train = callbacks.on_train
    on_before_optimize = callbacks.on_before_optimize
    on_after_optimize = callbacks.on_after_optimize
    on_after_epoch = callbacks.on_after_epoch
    on_eval = callbacks.on_eval
    on_sample = callbacks.on_sample
    on_checkpoint = callbacks.on_checkpoint

    isDadaptation = False

    try:
        import dadaptation

        isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
    except ImportError:
        pass

    try:
        for epoch in range(num_epochs):
            if accelerator.is_main_process:
                if epoch % sample_frequency == 0:
                    local_progress_bar.clear()
                    global_progress_bar.clear()

                    on_sample(global_step + global_step_offset)

                if epoch % checkpoint_frequency == 0 and epoch != 0:
                    local_progress_bar.clear()
                    global_progress_bar.clear()

                    on_checkpoint(global_step + global_step_offset, "training")

            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            logs = {}

            with on_train(epoch):
                for step, batch in enumerate(train_dataloader):
                    loss, acc, bsz = loss_step(step, batch, cache)
                    loss /= gradient_accumulation_steps

                    accelerator.backward(loss)

                    avg_loss.update(loss.detach_(), bsz)
                    avg_acc.update(acc.detach_(), bsz)

                    logs = {
                        "train/loss": avg_loss.avg.item(),
                        "train/acc": avg_acc.avg.item(),
                        "train/cur_loss": loss.item(),
                        "train/cur_acc": acc.item(),
                    }

                    lrs: dict[str, float] = {}
                    for i, lr in enumerate(lr_scheduler.get_last_lr()):
                        if torch.is_tensor(lr):
                            lr = lr.item()
                        label = group_labels[i] if i < len(group_labels) else f"{i}"
                        logs[f"lr/{label}"] = lr
                        if isDadaptation:
                            # For D-Adaptation, the effective lr is d * lr.
                            lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
                            logs[f"d*lr/{label}"] = lr
                        lrs[label] = lr

                    logs.update(on_log())

                    local_progress_bar.set_postfix(**logs)

                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
                        before_optimize_result = on_before_optimize(epoch)

                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                        on_after_optimize(before_optimize_result, lrs)

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        accelerator.log(logs, step=global_step)

                        global_step += 1

                    if global_step >= num_training_steps:
                        break

            accelerator.wait_for_everyone()

            on_after_epoch()

            if val_dataloader is not None:
                cur_loss_val = AverageMeter()
                cur_acc_val = AverageMeter()

                with torch.inference_mode(), on_eval():
                    for step, batch in enumerate(val_dataloader):
                        loss, acc, bsz = loss_step(step, batch, cache, True)

                        loss = loss.detach_()
                        acc = acc.detach_()

                        cur_loss_val.update(loss, bsz)
                        cur_acc_val.update(acc, bsz)

                        avg_loss_val.update(loss, bsz)
                        avg_acc_val.update(acc, bsz)

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        logs = {
                            "val/loss": avg_loss_val.avg.item(),
                            "val/acc": avg_acc_val.avg.item(),
                            "val/cur_loss": loss.item(),
                            "val/cur_acc": acc.item(),
                        }
                        local_progress_bar.set_postfix(**logs)

                logs["val/cur_loss"] = cur_loss_val.avg.item()
                logs["val/cur_acc"] = cur_acc_val.avg.item()

                accelerator.log(logs, step=global_step)

                if accelerator.is_main_process:
                    if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
                        local_progress_bar.clear()
                        global_progress_bar.clear()

                        accelerator.print(
                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                        on_checkpoint(global_step + global_step_offset, "milestone")
                        best_acc_val = avg_acc_val.avg.item()
            else:
                if accelerator.is_main_process:
                    if avg_acc.avg.item() > best_acc and milestone_checkpoints:
                        local_progress_bar.clear()
                        global_progress_bar.clear()

                        accelerator.print(
                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
                        on_checkpoint(global_step + global_step_offset, "milestone")
                        best_acc = avg_acc.avg.item()

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished!")
            on_checkpoint(global_step + global_step_offset, "end")
            on_sample(global_step + global_step_offset)

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted")
            on_checkpoint(global_step + global_step_offset, "end")
        raise KeyboardInterrupt
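
# Cadence notes, read off the loop above: samples are generated at epochs
# 0, sample_frequency, 2 * sample_frequency, ...; "training" checkpoints are
# written every checkpoint_frequency epochs (but never at epoch 0); and a
# "milestone" checkpoint is written whenever the tracked accuracy (validation
# accuracy if a val_dataloader exists, training accuracy otherwise) reaches a
# new maximum, provided milestone_checkpoints is enabled.
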
def train(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    noise_scheduler: SchedulerMixin,
    dtype: torch.dtype,
    seed: int,
    project: str,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    strategy: TrainingStrategy,
    no_val: bool = False,
    num_train_epochs: int = 100,
    gradient_accumulation_steps: int = 1,
    group_labels: list[str] = [],
    sample_frequency: int = 20,
    checkpoint_frequency: int = 50,
    milestone_checkpoints: bool = True,
    global_step_offset: int = 0,
    guidance_scale: float = 0.0,
    prior_loss_weight: float = 1.0,
    offset_noise_strength: float = 0.15,
    min_snr_gamma: int = 5,
    **kwargs,
):
    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)

    kwargs.update(extra)

    vae.to(accelerator.device, dtype=dtype)
    vae.requires_grad_(False)
    vae.eval()

    callbacks = strategy.callbacks(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        seed=seed,
        **kwargs,
    )

    loss_step_ = partial(
        loss_step,
        vae,
        noise_scheduler,
        unet,
        text_encoder,
        guidance_scale,
        prior_loss_weight,
        seed,
        offset_noise_strength,
        min_snr_gamma,
    )

    if accelerator.is_main_process:
        accelerator.init_trackers(project)

    train_loop(
        accelerator=accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader if not no_val else None,
        loss_step=loss_step_,
        sample_frequency=sample_frequency,
        checkpoint_frequency=checkpoint_frequency,
        milestone_checkpoints=milestone_checkpoints,
        global_step_offset=global_step_offset,
        num_epochs=num_train_epochs,
        gradient_accumulation_steps=gradient_accumulation_steps,
        group_labels=group_labels,
        callbacks=callbacks,
    )

    accelerator.end_training()
    accelerator.free_memory()
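
# End-to-end sketch (hypothetical objects throughout; `my_strategy`, the
# dataloaders, optimizer, and lr_scheduler are assumed to exist, e.g. from a
# concrete textual-inversion or Dreambooth strategy module):
#
#   tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
#       get_models("stabilityai/stable-diffusion-2-1")
#   train(
#       accelerator=Accelerator(gradient_accumulation_steps=1),
#       unet=unet, text_encoder=text_encoder, vae=vae,
#       noise_scheduler=noise_scheduler, dtype=torch.float32, seed=42,
#       project="my-project", train_dataloader=train_dataloader,
#       val_dataloader=val_dataloader, optimizer=optimizer,
#       lr_scheduler=lr_scheduler, strategy=my_strategy,
#   )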