import math

import torch
import torch.nn.functional as F

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from training.optimization import get_one_cycle_schedule


def get_scheduler(
    id: str,
    min_lr: float,
    lr: float,
    warmup_func: str,
    annealing_func: str,
    warmup_exp: int,
    annealing_exp: int,
    cycles: int,
    warmup_epochs: int,
    optimizer: torch.optim.Optimizer,
    max_train_steps: int,
    num_update_steps_per_epoch: int,
    gradient_accumulation_steps: int,
):
    """Build a learning-rate scheduler for ``optimizer``.

    Every step count handed to the underlying scheduler factories is expressed
    in micro-batch steps, i.e. optimizer-update steps multiplied by
    ``gradient_accumulation_steps``.

    Args:
        id: Scheduler name. ``"one_cycle"`` and ``"cosine_with_restarts"`` are
            handled here; any other name is forwarded to diffusers'
            ``get_scheduler``.
        min_lr: ``"one_cycle"`` only. Absolute minimum learning rate; it is
            converted to a fraction of ``lr``. May be ``None``, in which case
            the fraction defaults to 0.04.
        lr: Base learning rate, used only to normalise ``min_lr``.
        warmup_func, annealing_func, warmup_exp, annealing_exp: ``"one_cycle"``
            only. Forwarded to ``get_one_cycle_schedule``.
        cycles: ``"cosine_with_restarts"`` only. Number of hard restarts. When
            ``None`` it defaults to ceil(sqrt(post-warmup epochs)), with a
            minimum of 1.
        warmup_epochs: Length of the warmup phase in epochs. Not used by
            ``"one_cycle"``.
        optimizer: Optimizer whose learning rate is scheduled.
        max_train_steps: Total number of optimizer-update steps.
        num_update_steps_per_epoch: Optimizer-update steps per epoch.
        gradient_accumulation_steps: Micro-batches per optimizer update.

    Returns:
        The scheduler object built by the selected factory.
    """
    num_training_steps = max_train_steps * gradient_accumulation_steps
    # Warmup measured in optimizer updates (same unit as max_train_steps) ...
    warmup_update_steps = warmup_epochs * num_update_steps_per_epoch
    # ... and in micro-batch steps (same unit as num_training_steps).
    warmup_steps = warmup_update_steps * gradient_accumulation_steps

    if id == "one_cycle":
        min_lr = 0.04 if min_lr is None else min_lr / lr

        lr_scheduler = get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=num_training_steps,
            warmup=warmup_func,
            annealing=annealing_func,
            warmup_exp=warmup_exp,
            annealing_exp=annealing_exp,
            min_lr=min_lr,
        )
    elif id == "cosine_with_restarts":
        if cycles is None:
            # Count post-warmup epochs entirely in optimizer-update steps.
            # Mixing in the micro-batch based `warmup_steps` would skew the
            # result whenever gradient_accumulation_steps > 1, and could
            # make the sqrt argument negative. Clamp so there is always at
            # least one cycle.
            decay_update_steps = max(max_train_steps - warmup_update_steps, 0)
            decay_epochs = decay_update_steps / num_update_steps_per_epoch
            cycles = max(1, math.ceil(math.sqrt(decay_epochs)))

        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=cycles,
        )
    else:
        lr_scheduler = get_scheduler_(
            id,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )

    return lr_scheduler


def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    """Generate and save class images for training items that lack one.

    Items in ``data_train`` whose ``class_image_path`` does not exist are
    collected and split into batches of ``sample_batch_size``. A
    ``VlpnStableDiffusion`` pipeline is built from the given components and
    run once per batch. Each item's ``cprompt`` is the prompt and its
    ``nprompt`` is the negative prompt. Each resulting image is saved to the
    item's ``class_image_path``.

    Nothing is built or generated when every class image already exists.

    Args:
        accelerator: Object whose ``.device`` the pipeline is moved to.
        text_encoder, vae, unet, tokenizer, scheduler: Pipeline components.
        data_train: Iterable of items exposing ``class_image_path`` (something
            with ``.exists()``, presumably a ``pathlib.Path``), ``cprompt`` and
            ``nprompt``.
        sample_batch_size: Number of prompts per pipeline call.
        sample_image_size: Height and width passed to the pipeline.
        sample_steps: ``num_inference_steps`` for each pipeline call.
    """
    missing_data = [item for item in data_train if not item.class_image_path.exists()]

    if not missing_data:
        return

    batched_data = [
        missing_data[i:i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_paths = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            ).images

            # Presumably PIL images: one per prompt, in prompt order.
            for i, image in enumerate(images):
                image.save(image_paths[i])

    # Release the pipeline and any cached GPU memory before training resumes.
    del pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def loss_step(
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    prompt_processor,
    num_class_images: int,
    prior_loss_weight: float,
    seed: int,
    step: int,
    batch,
    eval: bool = False
):
    """Compute the denoising MSE loss for one batch.

    Steps:
        1. Encode ``batch["pixel_values"]`` with the VAE.
        2. Noise the latents at random timesteps.
        3. Predict the noise residual with the UNet, conditioned on text
           embeddings from ``prompt_processor``.
        4. Compare the prediction to the target implied by
           ``noise_scheduler.config.prediction_type``.

    Args:
        vae: Encodes pixel values into latents.
        noise_scheduler: Supplies ``add_noise``, ``get_velocity`` and the
            ``num_train_timesteps`` / ``prediction_type`` config values.
        unet: Noise-prediction model. Inputs are cast to its dtype.
        prompt_processor: Object exposing
            ``get_embeddings(input_ids, attention_mask)``.
        num_class_images: When non-zero, predictions and targets are split
            into two chunks along dim 0. The first chunk is scored as the
            instance loss and the second as the prior loss. (Assumes the batch
            is laid out as instance examples followed by class examples;
            TODO confirm against the collate function.)
        prior_loss_weight: Weight applied to the prior loss.
        seed, step: Used only when ``eval`` is True, to seed timestep sampling
            with ``seed + step``.
        batch: Mapping with ``"pixel_values"``, ``"input_ids"`` and
            ``"attention_mask"`` entries.
        eval: If True, timesteps are drawn from a seeded generator.

    Returns:
        A tuple ``(loss, acc, bsz)``:
            loss: Scalar loss tensor.
            acc: Fraction of elements where the prediction exactly equals the
                target. When ``num_class_images != 0`` this covers only the
                instance chunk.
            bsz: ``latents.shape[0]``, the full batch size.

    Raises:
        ValueError: If the scheduler's ``prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    # Convert images to latent space
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    # Presumably the Stable Diffusion VAE latent scaling factor
    # (vae.config.scaling_factor in newer diffusers) - confirm for other VAEs.
    latents = latents * 0.18215

    # Sample noise that we'll add to the latents
    # NOTE(review): only the timesteps below are seeded in eval mode. This
    # noise and the VAE latent sample both use the global RNG, so eval loss is
    # not fully reproducible.
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    # Sample a random timestep for each image
    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=timesteps_gen,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = prompt_processor.get_embeddings(
        batch["input_ids"],
        batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # NOTE(review): exact float equality is almost never true for network
    # outputs, so this metric is expected to be ~0. Consider a tolerance-based
    # comparison if it is meant to be informative.
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz