From 881e20dd52cb68dbe6b8f0a78c82a4ffcf3dea6d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 20:33:23 +0200 Subject: New loss scaling --- training/functional.py | 108 +++++++++++++++++++------------------------------ 1 file changed, 41 insertions(+), 67 deletions(-) diff --git a/training/functional.py b/training/functional.py index cc079ef..8917eb7 100644 --- a/training/functional.py +++ b/training/functional.py @@ -94,6 +94,8 @@ def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): pretrained_model_name_or_path, subfolder="scheduler" ) + prepare_scheduler_for_custom_training(noise_scheduler, "cuda") + return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler @@ -273,68 +275,39 @@ def add_placeholder_tokens( return placeholder_token_ids, initializer_token_ids -def compute_snr(timesteps, noise_scheduler): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ - timesteps - ].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( - device=timesteps.device - )[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. - snr = (alpha / sigma) ** 2 - return snr - - -def get_original( - noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor -): +def prepare_scheduler_for_custom_training(noise_scheduler, device): + if hasattr(noise_scheduler, "all_snr"): + return + alphas_cumprod = noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ - timesteps - ].float() - while len(sqrt_alphas_cumprod.shape) < len(sample.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(sample.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( - device=timesteps.device - )[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) + sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + + noise_scheduler.all_snr = all_snr.to(device) + + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) + gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) + snr_weight = torch.minimum( + gamma_over_snr, torch.ones_like(gamma_over_snr) + ).float() # from paper + loss = loss * snr_weight + return loss - if noise_scheduler.config.prediction_type == "epsilon": - pred_original_sample = (sample - sigma * model_output) / alpha - elif noise_scheduler.config.prediction_type == "sample": - pred_original_sample = model_output - elif noise_scheduler.config.prediction_type == "v_prediction": - pred_original_sample = alpha * sample - sigma * model_output - else: - raise ValueError( - f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" - " `v_prediction` for the DDPMScheduler." - ) - return pred_original_sample + +def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum( + snr_t, torch.ones_like(snr_t) * 1000 + ) # if timestep is 0, snr_t is inf, so limit it to 1000 + scale = snr_t / (snr_t + 1) + + loss = loss * scale + return loss def loss_step( @@ -347,6 +320,7 @@ def loss_step( seed: int, input_pertubation: float, min_snr_gamma: int, + scale_v_pred_loss_like_noise_pred: bool, step: int, batch: dict[str, Any], cache: dict[Any, Any], @@ -433,14 +407,12 @@ def loss_step( loss = loss.mean([1, 2, 3]) if min_snr_gamma != 0: - snr = compute_snr(timesteps, noise_scheduler) - mse_loss_weights = ( - torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] - / snr + loss = apply_snr_weight(loss, timesteps, noise_scheduler, min_snr_gamma) + + if scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction( + loss, timesteps, noise_scheduler ) - loss = loss * mse_loss_weights if isinstance(schedule_sampler, LossAwareSampler): schedule_sampler.update_with_all_losses(timesteps, loss.detach()) @@ -726,6 +698,7 @@ def train( input_pertubation: float = 0.1, schedule_sampler: Optional[ScheduleSampler] = None, min_snr_gamma: int = 5, + scale_v_pred_loss_like_noise_pred: bool = True, avg_loss: AverageMeter = AverageMeter(), avg_acc: AverageMeter = AverageMeter(), avg_loss_val: AverageMeter = AverageMeter(), @@ -785,6 +758,7 @@ def train( seed, input_pertubation, min_snr_gamma, + scale_v_pred_loss_like_noise_pred, ) train_loop( -- cgit v1.2.3-70-g09d2