diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 98 |
1 file changed, 36 insertions, 62 deletions
diff --git a/training/functional.py b/training/functional.py index cc079ef..8917eb7 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -94,6 +94,8 @@ def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): | |||
| 94 | pretrained_model_name_or_path, subfolder="scheduler" | 94 | pretrained_model_name_or_path, subfolder="scheduler" |
| 95 | ) | 95 | ) |
| 96 | 96 | ||
| 97 | prepare_scheduler_for_custom_training(noise_scheduler, "cuda") | ||
| 98 | |||
| 97 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler | 99 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler |
| 98 | 100 | ||
| 99 | 101 | ||
| @@ -273,68 +275,39 @@ def add_placeholder_tokens( | |||
| 273 | return placeholder_token_ids, initializer_token_ids | 275 | return placeholder_token_ids, initializer_token_ids |
| 274 | 276 | ||
| 275 | 277 | ||
| 276 | def compute_snr(timesteps, noise_scheduler): | 278 | def prepare_scheduler_for_custom_training(noise_scheduler, device): |
| 277 | """ | 279 | if hasattr(noise_scheduler, "all_snr"): |
| 278 | Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 | 280 | return |
| 279 | """ | ||
| 280 | alphas_cumprod = noise_scheduler.alphas_cumprod | ||
| 281 | sqrt_alphas_cumprod = alphas_cumprod**0.5 | ||
| 282 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 | ||
| 283 | |||
| 284 | # Expand the tensors. | ||
| 285 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 | ||
| 286 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ | ||
| 287 | timesteps | ||
| 288 | ].float() | ||
| 289 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): | ||
| 290 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | ||
| 291 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) | ||
| 292 | 281 | ||
| 293 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( | 282 | alphas_cumprod = noise_scheduler.alphas_cumprod |
| 294 | device=timesteps.device | 283 | sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) |
| 295 | )[timesteps].float() | 284 | sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) |
| 296 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): | 285 | alpha = sqrt_alphas_cumprod |
| 297 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | 286 | sigma = sqrt_one_minus_alphas_cumprod |
| 298 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) | 287 | all_snr = (alpha / sigma) ** 2 |
| 299 | 288 | ||
| 300 | # Compute SNR. | 289 | noise_scheduler.all_snr = all_snr.to(device) |
| 301 | snr = (alpha / sigma) ** 2 | ||
| 302 | return snr | ||
| 303 | 290 | ||
| 304 | 291 | ||
| 305 | def get_original( | 292 | def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): |
| 306 | noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor | 293 | snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) |
| 307 | ): | 294 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) |
| 308 | alphas_cumprod = noise_scheduler.alphas_cumprod | 295 | snr_weight = torch.minimum( |
| 309 | sqrt_alphas_cumprod = alphas_cumprod**0.5 | 296 | gamma_over_snr, torch.ones_like(gamma_over_snr) |
| 310 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 | 297 | ).float() # from paper |
| 298 | loss = loss * snr_weight | ||
| 299 | return loss | ||
| 311 | 300 | ||
| 312 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ | ||
| 313 | timesteps | ||
| 314 | ].float() | ||
| 315 | while len(sqrt_alphas_cumprod.shape) < len(sample.shape): | ||
| 316 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | ||
| 317 | alpha = sqrt_alphas_cumprod.expand(sample.shape) | ||
| 318 | 301 | ||
| 319 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( | 302 | def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): |
| 320 | device=timesteps.device | 303 | snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size |
| 321 | )[timesteps].float() | 304 | snr_t = torch.minimum( |
| 322 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): | 305 | snr_t, torch.ones_like(snr_t) * 1000 |
| 323 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | 306 | ) # if timestep is 0, snr_t is inf, so limit it to 1000 |
| 324 | sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) | 307 | scale = snr_t / (snr_t + 1) |
| 325 | 308 | ||
| 326 | if noise_scheduler.config.prediction_type == "epsilon": | 309 | loss = loss * scale |
| 327 | pred_original_sample = (sample - sigma * model_output) / alpha | 310 | return loss |
| 328 | elif noise_scheduler.config.prediction_type == "sample": | ||
| 329 | pred_original_sample = model_output | ||
| 330 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 331 | pred_original_sample = alpha * sample - sigma * model_output | ||
| 332 | else: | ||
| 333 | raise ValueError( | ||
| 334 | f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" | ||
| 335 | " `v_prediction` for the DDPMScheduler." | ||
| 336 | ) | ||
| 337 | return pred_original_sample | ||
| 338 | 311 | ||
| 339 | 312 | ||
| 340 | def loss_step( | 313 | def loss_step( |
| @@ -347,6 +320,7 @@ def loss_step( | |||
| 347 | seed: int, | 320 | seed: int, |
| 348 | input_pertubation: float, | 321 | input_pertubation: float, |
| 349 | min_snr_gamma: int, | 322 | min_snr_gamma: int, |
| 323 | scale_v_pred_loss_like_noise_pred: bool, | ||
| 350 | step: int, | 324 | step: int, |
| 351 | batch: dict[str, Any], | 325 | batch: dict[str, Any], |
| 352 | cache: dict[Any, Any], | 326 | cache: dict[Any, Any], |
| @@ -433,14 +407,12 @@ def loss_step( | |||
| 433 | loss = loss.mean([1, 2, 3]) | 407 | loss = loss.mean([1, 2, 3]) |
| 434 | 408 | ||
| 435 | if min_snr_gamma != 0: | 409 | if min_snr_gamma != 0: |
| 436 | snr = compute_snr(timesteps, noise_scheduler) | 410 | loss = apply_snr_weight(loss, timesteps, noise_scheduler, min_snr_gamma) |
| 437 | mse_loss_weights = ( | 411 | |
| 438 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min( | 412 | if scale_v_pred_loss_like_noise_pred: |
| 439 | dim=1 | 413 | loss = scale_v_prediction_loss_like_noise_prediction( |
| 440 | )[0] | 414 | loss, timesteps, noise_scheduler |
| 441 | / snr | ||
| 442 | ) | 415 | ) |
| 443 | loss = loss * mse_loss_weights | ||
| 444 | 416 | ||
| 445 | if isinstance(schedule_sampler, LossAwareSampler): | 417 | if isinstance(schedule_sampler, LossAwareSampler): |
| 446 | schedule_sampler.update_with_all_losses(timesteps, loss.detach()) | 418 | schedule_sampler.update_with_all_losses(timesteps, loss.detach()) |
| @@ -726,6 +698,7 @@ def train( | |||
| 726 | input_pertubation: float = 0.1, | 698 | input_pertubation: float = 0.1, |
| 727 | schedule_sampler: Optional[ScheduleSampler] = None, | 699 | schedule_sampler: Optional[ScheduleSampler] = None, |
| 728 | min_snr_gamma: int = 5, | 700 | min_snr_gamma: int = 5, |
| 701 | scale_v_pred_loss_like_noise_pred: bool = True, | ||
| 729 | avg_loss: AverageMeter = AverageMeter(), | 702 | avg_loss: AverageMeter = AverageMeter(), |
| 730 | avg_acc: AverageMeter = AverageMeter(), | 703 | avg_acc: AverageMeter = AverageMeter(), |
| 731 | avg_loss_val: AverageMeter = AverageMeter(), | 704 | avg_loss_val: AverageMeter = AverageMeter(), |
| @@ -785,6 +758,7 @@ def train( | |||
| 785 | seed, | 758 | seed, |
| 786 | input_pertubation, | 759 | input_pertubation, |
| 787 | min_snr_gamma, | 760 | min_snr_gamma, |
| 761 | scale_v_pred_loss_like_noise_pred, | ||
| 788 | ) | 762 | ) |
| 789 | 763 | ||
| 790 | train_loop( | 764 | train_loop( |
