from dataclasses import dataclass
import math
from contextlib import _GeneratorContextManager, nullcontext
from typing import Callable, Any, Tuple, Union, Optional
from functools import partial
from pathlib import Path
import itertools

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler

from tqdm.auto import tqdm
from PIL import Image

from data.csv import VlpnDataset
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
from models.clip.util import get_extended_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import AverageMeter


def const(result=None):
    """Return a function that ignores its arguments and always returns `result`."""
    def fn(*args, **kwargs):
        return result
    return fn


@dataclass
class TrainingCallbacks():
    on_prepare: Callable[[], None] = const()
    on_model: Callable[[], torch.nn.Module] = const(None)
    on_log: Callable[[], dict[str, Any]] = const({})
    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
    on_before_optimize: Callable[[int], None] = const()
    on_after_optimize: Callable[[float], None] = const()
    on_after_epoch: Callable[[float], None] = const()
    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
    on_sample: Callable[[int], None] = const()
    on_checkpoint: Callable[[int, str], None] = const()


@dataclass
class TrainingStrategy():
    callbacks: Callable[..., TrainingCallbacks]
    prepare_unet: bool = False


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def get_models(pretrained_model_name_or_path: str):
    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder='scheduler')

    embeddings = patch_managed_embeddings(text_encoder)

    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
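
# Example usage of get_models() (a sketch; the path is a placeholder for any
# diffusers-layout checkpoint with tokenizer/, text_encoder/, vae/, unet/ and
# scheduler/ subfolders):
#
#     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
#         get_models("path/to/stable-diffusion-checkpoint")
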
def save_samples(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    dtype: torch.dtype,
    output_dir: Path,
    seed: int,
    step: int,
    batch_size: int = 1,
    num_batches: int = 1,
    num_steps: int = 20,
    guidance_scale: float = 7.5,
    image_size: Optional[int] = None,
):
    print(f"Saving samples for step {step}...")

    grid_cols = min(batch_size, 4)
    grid_rows = (num_batches * batch_size) // grid_cols

    unet = accelerator.unwrap_model(unet)
    text_encoder = accelerator.unwrap_model(text_encoder)

    orig_unet_dtype = unet.dtype
    orig_text_encoder_dtype = text_encoder.dtype

    unet.to(dtype=dtype)
    text_encoder.to(dtype=dtype)

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=sample_scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    generator = torch.Generator(device=accelerator.device).manual_seed(seed)

    # "train" and "val" pools sample with fresh noise each call; "stable" reuses
    # the seeded generator so the same prompts can be compared across steps.
    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]

    if val_dataloader is not None:
        datasets.append(("stable", val_dataloader, generator))
        datasets.append(("val", val_dataloader, None))

    for pool, data, gen in datasets:
        all_samples = []
        file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Cycle the dataloader so enough prompts are available even when the
        # dataset is smaller than the number of samples requested.
        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
        prompt_ids = [
            prompt
            for batch in batches
            for prompt in batch["prompt_ids"]
        ]
        nprompt_ids = [
            prompt
            for batch in batches
            for prompt in batch["nprompt_ids"]
        ]

        for i in range(num_batches):
            start = i * batch_size
            end = (i + 1) * batch_size
            prompt = prompt_ids[start:end]
            nprompt = nprompt_ids[start:end]

            samples = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=image_size,
                width=image_size,
                generator=gen,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
                output_type='pil'
            ).images

            all_samples += samples

        image_grid = make_grid(all_samples, grid_rows, grid_cols)
        image_grid.save(file_path, quality=85)

    # Restore the original dtypes before training resumes.
    unet.to(dtype=orig_unet_dtype)
    text_encoder.to(dtype=orig_text_encoder_dtype)

    del unet
    del text_encoder
    del generator
    del pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def generate_class_images(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    tokenizer: MultiCLIPTokenizer,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataset: VlpnDataset,
    sample_batch_size: int,
    sample_image_size: int,
    sample_steps: int
):
    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]

    if len(missing_data) == 0:
        return

    batched_data = [
        missing_data[i:i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=sample_scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_name = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            ).images

            for i, image in enumerate(images):
                image.save(image_name[i])

    del pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def add_placeholder_tokens(
    tokenizer: MultiCLIPTokenizer,
    embeddings: ManagedCLIPTextEmbeddings,
    placeholder_tokens: list[str],
    initializer_tokens: list[str],
    num_vectors: Union[list[int], int]
):
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)
        for token in initializer_tokens
    ]
    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

    embeddings.resize(len(tokenizer))

    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
        embeddings.add_embed(placeholder_token_id, initializer_token_id)

    return placeholder_token_ids, initializer_token_ids
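
# Example usage of add_placeholder_tokens() (a sketch; the token strings and
# vector count are illustrative, not part of this module):
#
#     placeholder_ids, initializer_ids = add_placeholder_tokens(
#         tokenizer,
#         embeddings,
#         placeholder_tokens=["<concept>"],
#         initializer_tokens=["sculpture"],
#         num_vectors=2,
#     )
#
# A placeholder may span several embedding vectors (num_vectors); each new
# embedding is initialized from the corresponding initializer token.
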
def loss_step(
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    with_prior_preservation: bool,
    prior_loss_weight: float,
    seed: int,
    step: int,
    batch: dict[str, Any],
    eval: bool = False
):
    # Convert images to latent space
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    latents = latents * 0.18215

    # During evaluation, seed the generator from (seed, step) so that noise and
    # timesteps are deterministic and validation losses are comparable across epochs.
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

    # Sample noise that we'll add to the latents
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator
    )
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=generator,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = get_extended_embeddings(
        text_encoder,
        batch["input_ids"],
        batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if with_prior_preservation:
        # Chunk model_pred and target into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    acc = (model_pred == target).float().mean()

    return loss, acc, bsz
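
# For reference: with prediction_type == "epsilon" the UNet regresses the injected
# noise directly, while diffusers' DDPMScheduler.get_velocity returns the
# v-prediction target
#
#     v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents
#
# (the velocity parameterization of Salimans & Ho, 2022).
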
def train_loop(
    accelerator: Accelerator,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
    sample_frequency: int = 10,
    checkpoint_frequency: int = 50,
    global_step_offset: int = 0,
    num_epochs: int = 100,
    callbacks: TrainingCallbacks = TrainingCallbacks(),
):
    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0

    num_training_steps = num_training_steps_per_epoch * num_epochs
    num_val_steps = num_val_steps_per_epoch * num_epochs

    global_step = 0

    avg_loss = AverageMeter()
    avg_acc = AverageMeter()

    avg_loss_val = AverageMeter()
    avg_acc_val = AverageMeter()

    max_acc = 0.0
    max_acc_val = 0.0

    local_progress_bar = tqdm(
        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

    global_progress_bar = tqdm(
        range(num_training_steps + num_val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    global_progress_bar.set_description("Total progress")

    model = callbacks.on_model()
    on_log = callbacks.on_log
    on_train = callbacks.on_train
    on_before_optimize = callbacks.on_before_optimize
    on_after_optimize = callbacks.on_after_optimize
    on_after_epoch = callbacks.on_after_epoch
    on_eval = callbacks.on_eval
    on_sample = callbacks.on_sample
    on_checkpoint = callbacks.on_checkpoint

    try:
        for epoch in range(num_epochs):
            if accelerator.is_main_process:
                if epoch % sample_frequency == 0:
                    on_sample(global_step + global_step_offset)

                if epoch % checkpoint_frequency == 0 and epoch != 0:
                    on_checkpoint(global_step + global_step_offset, "training")

            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            model.train()

            with on_train(epoch):
                for step, batch in enumerate(train_dataloader):
                    with accelerator.accumulate(model):
                        loss, acc, bsz = loss_step(step, batch)

                        accelerator.backward(loss)

                        on_before_optimize(epoch)

                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                        avg_loss.update(loss.detach_(), bsz)
                        avg_acc.update(acc.detach_(), bsz)

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if accelerator.sync_gradients:
                        on_after_optimize(lr_scheduler.get_last_lr()[0])

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        global_step += 1

                        logs = {
                            "train/loss": avg_loss.avg.item(),
                            "train/acc": avg_acc.avg.item(),
                            "train/cur_loss": loss.item(),
                            "train/cur_acc": acc.item(),
                            "lr": lr_scheduler.get_last_lr()[0],
                        }
                        logs.update(on_log())

                        accelerator.log(logs, step=global_step)

                        local_progress_bar.set_postfix(**logs)

                    if global_step >= num_training_steps:
                        break

            accelerator.wait_for_everyone()

            on_after_epoch(lr_scheduler.get_last_lr()[0])

            if val_dataloader is not None:
                model.eval()

                cur_loss_val = AverageMeter()
                cur_acc_val = AverageMeter()

                with torch.inference_mode(), on_eval():
                    for step, batch in enumerate(val_dataloader):
                        loss, acc, bsz = loss_step(step, batch, True)

                        loss = loss.detach_()
                        acc = acc.detach_()

                        cur_loss_val.update(loss, bsz)
                        cur_acc_val.update(acc, bsz)

                        avg_loss_val.update(loss, bsz)
                        avg_acc_val.update(acc, bsz)

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        logs = {
                            "val/loss": avg_loss_val.avg.item(),
                            "val/acc": avg_acc_val.avg.item(),
                            "val/cur_loss": loss.item(),
                            "val/cur_acc": acc.item(),
                        }
                        local_progress_bar.set_postfix(**logs)

                logs["val/cur_loss"] = cur_loss_val.avg.item()
                logs["val/cur_acc"] = cur_acc_val.avg.item()

                accelerator.log(logs, step=global_step)

                local_progress_bar.clear()
                global_progress_bar.clear()

                if accelerator.is_main_process:
                    if avg_acc_val.avg.item() > max_acc_val:
                        accelerator.print(
                            f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                        on_checkpoint(global_step + global_step_offset, "milestone")
                        max_acc_val = avg_acc_val.avg.item()
            else:
                if accelerator.is_main_process:
                    if avg_acc.avg.item() > max_acc:
                        accelerator.print(
                            f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}")
                        on_checkpoint(global_step + global_step_offset, "milestone")
                        max_acc = avg_acc.avg.item()

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished!")
            on_checkpoint(global_step + global_step_offset, "end")
            on_sample(global_step + global_step_offset)

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted")
            on_checkpoint(global_step + global_step_offset, "end")
        raise
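
# Minimal TrainingStrategy sketch (an illustration, not one of the real
# strategies): the callbacks factory must accept the keyword arguments that
# train() below forwards to it, and on_model must return the module that
# accelerator.accumulate() should track.
#
#     def _callbacks(accelerator, unet, text_encoder, vae,
#                    train_dataloader, val_dataloader, seed, **kwargs):
#         return TrainingCallbacks(on_model=const(text_encoder))
#
#     noop_strategy = TrainingStrategy(callbacks=_callbacks, prepare_unet=False)
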
def train(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    dtype: torch.dtype,
    seed: int,
    project: str,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    strategy: TrainingStrategy,
    num_train_epochs: int = 100,
    sample_frequency: int = 20,
    checkpoint_frequency: int = 50,
    global_step_offset: int = 0,
    with_prior_preservation: bool = False,
    prior_loss_weight: float = 1.0,
    **kwargs,
):
    prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
    if strategy.prepare_unet:
        prep.append(unet)

    prep = accelerator.prepare(*prep)

    if strategy.prepare_unet:
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
    else:
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep

    unet.to(accelerator.device, dtype=dtype)
    vae.to(accelerator.device, dtype=dtype)

    # Freeze everything; the strategy's callbacks (e.g. on_prepare) are
    # responsible for re-enabling gradients on the parameters being trained.
    for model in (unet, text_encoder, vae):
        model.requires_grad_(False)
        model.eval()

    callbacks = strategy.callbacks(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        seed=seed,
        **kwargs,
    )

    callbacks.on_prepare()

    loss_step_ = partial(
        loss_step,
        vae,
        noise_scheduler,
        unet,
        text_encoder,
        with_prior_preservation,
        prior_loss_weight,
        seed,
    )

    if accelerator.is_main_process:
        accelerator.init_trackers(project)

    train_loop(
        accelerator=accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_step=loss_step_,
        sample_frequency=sample_frequency,
        checkpoint_frequency=checkpoint_frequency,
        global_step_offset=global_step_offset,
        num_epochs=num_train_epochs,
        callbacks=callbacks,
    )

    accelerator.end_training()
    accelerator.free_memory()
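
# End-to-end sketch of wiring these pieces together (assumptions: a single
# process, AdamW over the managed embeddings, a constant LR schedule, and the
# hypothetical noop_strategy from the comment above train(); the dataloaders
# are assumed to yield the batch keys loss_step() expects):
#
#     accelerator = Accelerator(gradient_accumulation_steps=1)
#     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
#         get_models("path/to/stable-diffusion-checkpoint")
#     optimizer = torch.optim.AdamW(text_encoder.text_model.embeddings.parameters(), lr=1e-4)
#     lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
#
#     train(
#         accelerator=accelerator,
#         unet=unet,
#         text_encoder=text_encoder,
#         vae=vae,
#         noise_scheduler=noise_scheduler,
#         dtype=torch.float32,
#         seed=42,
#         project="example",
#         train_dataloader=train_dataloader,
#         val_dataloader=None,
#         optimizer=optimizer,
#         lr_scheduler=lr_scheduler,
#         strategy=noop_strategy,
#     )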