diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 108 |
1 files changed, 41 insertions, 67 deletions
diff --git a/training/functional.py b/training/functional.py index cc079ef..8917eb7 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -94,6 +94,8 @@ def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): | |||
94 | pretrained_model_name_or_path, subfolder="scheduler" | 94 | pretrained_model_name_or_path, subfolder="scheduler" |
95 | ) | 95 | ) |
96 | 96 | ||
97 | prepare_scheduler_for_custom_training(noise_scheduler, "cuda") | ||
98 | |||
97 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler | 99 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler |
98 | 100 | ||
99 | 101 | ||
@@ -273,68 +275,39 @@ def add_placeholder_tokens( | |||
273 | return placeholder_token_ids, initializer_token_ids | 275 | return placeholder_token_ids, initializer_token_ids |
274 | 276 | ||
275 | 277 | ||
276 | def compute_snr(timesteps, noise_scheduler): | 278 | def prepare_scheduler_for_custom_training(noise_scheduler, device): |
277 | """ | 279 | if hasattr(noise_scheduler, "all_snr"): |
278 | Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 | 280 | return |
279 | """ | 281 | |
280 | alphas_cumprod = noise_scheduler.alphas_cumprod | ||
281 | sqrt_alphas_cumprod = alphas_cumprod**0.5 | ||
282 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 | ||
283 | |||
284 | # Expand the tensors. | ||
285 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 | ||
286 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ | ||
287 | timesteps | ||
288 | ].float() | ||
289 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): | ||
290 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | ||
291 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) | ||
292 | |||
293 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( | ||
294 | device=timesteps.device | ||
295 | )[timesteps].float() | ||
296 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): | ||
297 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | ||
298 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) | ||
299 | |||
300 | # Compute SNR. | ||
301 | snr = (alpha / sigma) ** 2 | ||
302 | return snr | ||
303 | |||
304 | |||
305 | def get_original( | ||
306 | noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor | ||
307 | ): | ||
308 | alphas_cumprod = noise_scheduler.alphas_cumprod | 282 | alphas_cumprod = noise_scheduler.alphas_cumprod |
309 | sqrt_alphas_cumprod = alphas_cumprod**0.5 | 283 | sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) |
310 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 | 284 | sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) |
311 | 285 | alpha = sqrt_alphas_cumprod | |
312 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ | 286 | sigma = sqrt_one_minus_alphas_cumprod |
313 | timesteps | 287 | all_snr = (alpha / sigma) ** 2 |
314 | ].float() | 288 | |
315 | while len(sqrt_alphas_cumprod.shape) < len(sample.shape): | 289 | noise_scheduler.all_snr = all_snr.to(device) |
316 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | 290 | |
317 | alpha = sqrt_alphas_cumprod.expand(sample.shape) | 291 | |
318 | 292 | def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): | |
319 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( | 293 | snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) |
320 | device=timesteps.device | 294 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) |
321 | )[timesteps].float() | 295 | snr_weight = torch.minimum( |
322 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): | 296 | gamma_over_snr, torch.ones_like(gamma_over_snr) |
323 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | 297 | ).float() # from paper |
324 | sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) | 298 | loss = loss * snr_weight |
299 | return loss | ||
325 | 300 | ||
326 | if noise_scheduler.config.prediction_type == "epsilon": | 301 | |
327 | pred_original_sample = (sample - sigma * model_output) / alpha | 302 | def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): |
328 | elif noise_scheduler.config.prediction_type == "sample": | 303 | snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size |
329 | pred_original_sample = model_output | 304 | snr_t = torch.minimum( |
330 | elif noise_scheduler.config.prediction_type == "v_prediction": | 305 | snr_t, torch.ones_like(snr_t) * 1000 |
331 | pred_original_sample = alpha * sample - sigma * model_output | 306 | ) # if timestep is 0, snr_t is inf, so limit it to 1000 |
332 | else: | 307 | scale = snr_t / (snr_t + 1) |
333 | raise ValueError( | 308 | |
334 | f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" | 309 | loss = loss * scale |
335 | " `v_prediction` for the DDPMScheduler." | 310 | return loss |
336 | ) | ||
337 | return pred_original_sample | ||
338 | 311 | ||
339 | 312 | ||
340 | def loss_step( | 313 | def loss_step( |
@@ -347,6 +320,7 @@ def loss_step( | |||
347 | seed: int, | 320 | seed: int, |
348 | input_pertubation: float, | 321 | input_pertubation: float, |
349 | min_snr_gamma: int, | 322 | min_snr_gamma: int, |
323 | scale_v_pred_loss_like_noise_pred: bool, | ||
350 | step: int, | 324 | step: int, |
351 | batch: dict[str, Any], | 325 | batch: dict[str, Any], |
352 | cache: dict[Any, Any], | 326 | cache: dict[Any, Any], |
@@ -433,14 +407,12 @@ def loss_step( | |||
433 | loss = loss.mean([1, 2, 3]) | 407 | loss = loss.mean([1, 2, 3]) |
434 | 408 | ||
435 | if min_snr_gamma != 0: | 409 | if min_snr_gamma != 0: |
436 | snr = compute_snr(timesteps, noise_scheduler) | 410 | loss = apply_snr_weight(loss, timesteps, noise_scheduler, min_snr_gamma) |
437 | mse_loss_weights = ( | 411 | |
438 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min( | 412 | if scale_v_pred_loss_like_noise_pred: |
439 | dim=1 | 413 | loss = scale_v_prediction_loss_like_noise_prediction( |
440 | )[0] | 414 | loss, timesteps, noise_scheduler |
441 | / snr | ||
442 | ) | 415 | ) |
443 | loss = loss * mse_loss_weights | ||
444 | 416 | ||
445 | if isinstance(schedule_sampler, LossAwareSampler): | 417 | if isinstance(schedule_sampler, LossAwareSampler): |
446 | schedule_sampler.update_with_all_losses(timesteps, loss.detach()) | 418 | schedule_sampler.update_with_all_losses(timesteps, loss.detach()) |
@@ -726,6 +698,7 @@ def train( | |||
726 | input_pertubation: float = 0.1, | 698 | input_pertubation: float = 0.1, |
727 | schedule_sampler: Optional[ScheduleSampler] = None, | 699 | schedule_sampler: Optional[ScheduleSampler] = None, |
728 | min_snr_gamma: int = 5, | 700 | min_snr_gamma: int = 5, |
701 | scale_v_pred_loss_like_noise_pred: bool = True, | ||
729 | avg_loss: AverageMeter = AverageMeter(), | 702 | avg_loss: AverageMeter = AverageMeter(), |
730 | avg_acc: AverageMeter = AverageMeter(), | 703 | avg_acc: AverageMeter = AverageMeter(), |
731 | avg_loss_val: AverageMeter = AverageMeter(), | 704 | avg_loss_val: AverageMeter = AverageMeter(), |
@@ -785,6 +758,7 @@ def train( | |||
785 | seed, | 758 | seed, |
786 | input_pertubation, | 759 | input_pertubation, |
787 | min_snr_gamma, | 760 | min_snr_gamma, |
761 | scale_v_pred_loss_like_noise_pred, | ||
788 | ) | 762 | ) |
789 | 763 | ||
790 | train_loop( | 764 | train_loop( |