diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 33 |
1 files changed, 21 insertions, 12 deletions
diff --git a/training/functional.py b/training/functional.py index 38dd59f..e7e1eb3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -324,6 +324,7 @@ def loss_step( | |||
324 | prior_loss_weight: float, | 324 | prior_loss_weight: float, |
325 | seed: int, | 325 | seed: int, |
326 | offset_noise_strength: float, | 326 | offset_noise_strength: float, |
327 | input_pertubation: float, | ||
327 | disc: Optional[ConvNeXtDiscriminator], | 328 | disc: Optional[ConvNeXtDiscriminator], |
328 | min_snr_gamma: int, | 329 | min_snr_gamma: int, |
329 | step: int, | 330 | step: int, |
@@ -337,7 +338,7 @@ def loss_step( | |||
337 | 338 | ||
338 | # Convert images to latent space | 339 | # Convert images to latent space |
339 | latents = vae.encode(images).latent_dist.sample(generator=generator) | 340 | latents = vae.encode(images).latent_dist.sample(generator=generator) |
340 | latents *= vae.config.scaling_factor | 341 | latents = latents * vae.config.scaling_factor |
341 | 342 | ||
342 | # Sample noise that we'll add to the latents | 343 | # Sample noise that we'll add to the latents |
343 | noise = torch.randn( | 344 | noise = torch.randn( |
@@ -355,7 +356,10 @@ def loss_step( | |||
355 | device=latents.device, | 356 | device=latents.device, |
356 | generator=generator | 357 | generator=generator |
357 | ).expand(noise.shape) | 358 | ).expand(noise.shape) |
358 | noise += offset_noise_strength * offset_noise | 359 | noise = noise + offset_noise_strength * offset_noise |
360 | |||
361 | if input_pertubation != 0: | ||
362 | new_noise = noise + input_pertubation * torch.randn_like(noise) | ||
359 | 363 | ||
360 | # Sample a random timestep for each image | 364 | # Sample a random timestep for each image |
361 | timesteps = torch.randint( | 365 | timesteps = torch.randint( |
@@ -369,7 +373,10 @@ def loss_step( | |||
369 | 373 | ||
370 | # Add noise to the latents according to the noise magnitude at each timestep | 374 | # Add noise to the latents according to the noise magnitude at each timestep |
371 | # (this is the forward diffusion process) | 375 | # (this is the forward diffusion process) |
372 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 376 | if input_pertubation != 0: |
377 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) | ||
378 | else: | ||
379 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
373 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 380 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
374 | 381 | ||
375 | # Get the text embedding for conditioning | 382 | # Get the text embedding for conditioning |
@@ -381,7 +388,7 @@ def loss_step( | |||
381 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 388 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) |
382 | 389 | ||
383 | # Predict the noise residual | 390 | # Predict the noise residual |
384 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 391 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] |
385 | 392 | ||
386 | if guidance_scale != 0: | 393 | if guidance_scale != 0: |
387 | uncond_encoder_hidden_states = get_extended_embeddings( | 394 | uncond_encoder_hidden_states = get_extended_embeddings( |
@@ -391,7 +398,7 @@ def loss_step( | |||
391 | ) | 398 | ) |
392 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | 399 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) |
393 | 400 | ||
394 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | 401 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] |
395 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 402 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) |
396 | 403 | ||
397 | # Get the target for loss depending on the prediction type | 404 | # Get the target for loss depending on the prediction type |
@@ -424,9 +431,9 @@ def loss_step( | |||
424 | 431 | ||
425 | if disc is not None: | 432 | if disc is not None: |
426 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | 433 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) |
427 | rec_latent /= vae.config.scaling_factor | 434 | rec_latent = rec_latent / vae.config.scaling_factor |
428 | rec_latent = rec_latent.to(dtype=vae.dtype) | 435 | rec_latent = rec_latent.to(dtype=vae.dtype) |
429 | rec = vae.decode(rec_latent).sample | 436 | rec = vae.decode(rec_latent, return_dict=False)[0] |
430 | loss = 1 - disc.get_score(rec) | 437 | loss = 1 - disc.get_score(rec) |
431 | 438 | ||
432 | if min_snr_gamma != 0: | 439 | if min_snr_gamma != 0: |
@@ -434,7 +441,7 @@ def loss_step( | |||
434 | mse_loss_weights = ( | 441 | mse_loss_weights = ( |
435 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | 442 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr |
436 | ) | 443 | ) |
437 | loss *= mse_loss_weights | 444 | loss = loss * mse_loss_weights |
438 | 445 | ||
439 | loss = loss.mean() | 446 | loss = loss.mean() |
440 | 447 | ||
@@ -539,7 +546,7 @@ def train_loop( | |||
539 | with on_train(cycle): | 546 | with on_train(cycle): |
540 | for step, batch in enumerate(train_dataloader): | 547 | for step, batch in enumerate(train_dataloader): |
541 | loss, acc, bsz = loss_step(step, batch, cache) | 548 | loss, acc, bsz = loss_step(step, batch, cache) |
542 | loss /= gradient_accumulation_steps | 549 | loss = loss / gradient_accumulation_steps |
543 | 550 | ||
544 | accelerator.backward(loss) | 551 | accelerator.backward(loss) |
545 | 552 | ||
@@ -598,7 +605,7 @@ def train_loop( | |||
598 | with torch.inference_mode(), on_eval(): | 605 | with torch.inference_mode(), on_eval(): |
599 | for step, batch in enumerate(val_dataloader): | 606 | for step, batch in enumerate(val_dataloader): |
600 | loss, acc, bsz = loss_step(step, batch, cache, True) | 607 | loss, acc, bsz = loss_step(step, batch, cache, True) |
601 | loss /= gradient_accumulation_steps | 608 | loss = loss / gradient_accumulation_steps |
602 | 609 | ||
603 | cur_loss_val.update(loss.item(), bsz) | 610 | cur_loss_val.update(loss.item(), bsz) |
604 | cur_acc_val.update(acc.item(), bsz) | 611 | cur_acc_val.update(acc.item(), bsz) |
@@ -684,7 +691,8 @@ def train( | |||
684 | global_step_offset: int = 0, | 691 | global_step_offset: int = 0, |
685 | guidance_scale: float = 0.0, | 692 | guidance_scale: float = 0.0, |
686 | prior_loss_weight: float = 1.0, | 693 | prior_loss_weight: float = 1.0, |
687 | offset_noise_strength: float = 0.15, | 694 | offset_noise_strength: float = 0.01, |
695 | input_pertubation: float = 0.1, | ||
688 | disc: Optional[ConvNeXtDiscriminator] = None, | 696 | disc: Optional[ConvNeXtDiscriminator] = None, |
689 | min_snr_gamma: int = 5, | 697 | min_snr_gamma: int = 5, |
690 | avg_loss: AverageMeter = AverageMeter(), | 698 | avg_loss: AverageMeter = AverageMeter(), |
@@ -704,7 +712,7 @@ def train( | |||
704 | 712 | ||
705 | if compile_unet: | 713 | if compile_unet: |
706 | unet = torch.compile(unet, backend='hidet') | 714 | unet = torch.compile(unet, backend='hidet') |
707 | # unet = torch.compile(unet) | 715 | # unet = torch.compile(unet, mode="reduce-overhead") |
708 | 716 | ||
709 | callbacks = strategy.callbacks( | 717 | callbacks = strategy.callbacks( |
710 | accelerator=accelerator, | 718 | accelerator=accelerator, |
@@ -727,6 +735,7 @@ def train( | |||
727 | prior_loss_weight, | 735 | prior_loss_weight, |
728 | seed, | 736 | seed, |
729 | offset_noise_strength, | 737 | offset_noise_strength, |
738 | input_pertubation, | ||
730 | disc, | 739 | disc, |
731 | min_snr_gamma, | 740 | min_snr_gamma, |
732 | ) | 741 | ) |