Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  36
1 file changed, 21 insertions(+), 15 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 87bb339..d285366 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -274,7 +274,7 @@ def loss_step(
     noise_scheduler: SchedulerMixin,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior_preservation: bool,
+    guidance_scale: float,
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
@@ -283,13 +283,13 @@ def loss_step(
     eval: bool = False,
     min_snr_gamma: int = 5,
 ):
-    # Convert images to latent space
-    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-    latents = latents * vae.config.scaling_factor
-
-    bsz = latents.shape[0]
+    images = batch["pixel_values"]
+    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    bsz = images.shape[0]
 
-    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+    # Convert images to latent space
+    latents = vae.encode(images).latent_dist.sample(generator=generator)
+    latents *= vae.config.scaling_factor
 
     # Sample noise that we'll add to the latents
     noise = torch.randn(
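Note on the reordering above: the seeded generator is now created before the VAE encode, so during eval both the latent sampling and the noise drawn afterwards are deterministic for a given step. A minimal standalone sketch of the idea, assuming a diffusers-style vae; the function name and arguments are illustrative, not part of the patch:

import torch

def encode_latents(vae, images, seed, step, eval=False):
    # During eval, derive a per-step generator so sampled latents (and any
    # later noise drawn from it) are reproducible; during training, stay stochastic.
    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
    latents = vae.encode(images).latent_dist.sample(generator=generator)
    latents *= vae.config.scaling_factor  # scaling factor from the VAE config
    return latents, generator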
@@ -301,13 +301,13 @@ def loss_step(
     )
 
     if offset_noise_strength != 0:
-        noise += offset_noise_strength * perlin_noise(
-            latents.shape,
-            res=1,
+        offset_noise = torch.randn(
+            (latents.shape[0], latents.shape[1], 1, 1),
             dtype=latents.dtype,
             device=latents.device,
             generator=generator
-        )
+        ).expand(noise.shape)
+        noise += offset_noise_strength * offset_noise
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
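This hunk swaps the Perlin-based offset noise for the more common formulation: one extra Gaussian value per (sample, channel), broadcast over height and width and mixed into the base noise at offset_noise_strength. A self-contained sketch of the operation; the helper name is made up for illustration:

import torch

def apply_offset_noise(noise, strength, generator=None):
    # One scalar per (batch, channel), expanded over H and W, then mixed in.
    offset = torch.randn(
        (noise.shape[0], noise.shape[1], 1, 1),
        dtype=noise.dtype,
        device=noise.device,
        generator=generator,
    ).expand(noise.shape)
    return noise + strength * offset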
@@ -343,7 +343,13 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if with_prior_preservation:
+    if guidance_scale != 0:
+        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+        model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0)
+        model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond)
+
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+    elif prior_loss_weight != 0:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
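The new branch applies classifier-free guidance inside the loss: when guidance_scale != 0, the model prediction is assumed to stack an unconditional and a text-conditioned half along dim 0, and the two halves are recombined with the standard CFG formula before the per-element MSE. A rough standalone sketch under that assumption (not the surrounding code; target is expected to match the batch size of the guided half):

import torch
import torch.nn.functional as F

def cfg_guided_loss(model_pred, target, guidance_scale):
    # model_pred stacks [unconditional, text-conditioned] predictions along dim 0.
    uncond, cond = torch.chunk(model_pred, 2, dim=0)
    guided = uncond + guidance_scale * (cond - uncond)
    # Per-element MSE against the matching target half.
    return F.mse_loss(guided.float(), target.float(), reduction="none")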
@@ -607,9 +613,9 @@ def train(
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
-    with_prior_preservation: bool = False,
+    guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
-    offset_noise_strength: float = 0.1,
+    offset_noise_strength: float = 0.15,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -638,7 +644,7 @@ def train(
         noise_scheduler,
         unet,
         text_encoder,
-        with_prior_preservation,
+        guidance_scale,
         prior_loss_weight,
         seed,
         offset_noise_strength,
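Call-site summary: train() now exposes guidance_scale (default 0.0, which leaves the CFG branch off) in place of the with_prior_preservation flag, the prior-preservation branch keys off prior_loss_weight instead, and the default offset_noise_strength rises from 0.1 to 0.15. A hypothetical snippet showing only the keyword arguments touched by this commit:

# Only the kwargs affected by this change; all other train() arguments elided.
train_kwargs = dict(
    guidance_scale=0.0,          # 0 disables the CFG branch in loss_step
    prior_loss_weight=1.0,       # prior-preservation loss applies when non-zero
    offset_noise_strength=0.15,  # default raised from 0.1
)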
