diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 17 |
1 file changed, 11 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py
index 62b8260..a9c7a8a 100644
--- a/training/functional.py
+++ b/training/functional.py
| @@ -256,6 +256,7 @@ def loss_step( | |||
| 256 | text_encoder: CLIPTextModel, | 256 | text_encoder: CLIPTextModel, |
| 257 | with_prior_preservation: bool, | 257 | with_prior_preservation: bool, |
| 258 | prior_loss_weight: float, | 258 | prior_loss_weight: float, |
| 259 | low_freq_noise: float, | ||
| 259 | seed: int, | 260 | seed: int, |
| 260 | step: int, | 261 | step: int, |
| 261 | batch: dict[str, Any], | 262 | batch: dict[str, Any], |
| @@ -274,13 +275,15 @@ def loss_step( | |||
| 274 | layout=latents.layout, | 275 | layout=latents.layout, |
| 275 | device=latents.device, | 276 | device=latents.device, |
| 276 | generator=generator | 277 | generator=generator |
| 277 | ) + 0.1 * torch.randn( | ||
| 278 | latents.shape[0], latents.shape[1], 1, 1, | ||
| 279 | dtype=latents.dtype, | ||
| 280 | layout=latents.layout, | ||
| 281 | device=latents.device, | ||
| 282 | generator=generator | ||
| 283 | ) | 278 | ) |
| 279 | if low_freq_noise > 0: | ||
| 280 | noise += low_freq_noise * torch.randn( | ||
| 281 | latents.shape[0], latents.shape[1], 1, 1, | ||
| 282 | dtype=latents.dtype, | ||
| 283 | layout=latents.layout, | ||
| 284 | device=latents.device, | ||
| 285 | generator=generator | ||
| 286 | ) | ||
| 284 | bsz = latents.shape[0] | 287 | bsz = latents.shape[0] |
| 285 | # Sample a random timestep for each image | 288 | # Sample a random timestep for each image |
| 286 | timesteps = torch.randint( | 289 | timesteps = torch.randint( |
| @@ -553,6 +556,7 @@ def train( | |||
| 553 | global_step_offset: int = 0, | 556 | global_step_offset: int = 0, |
| 554 | with_prior_preservation: bool = False, | 557 | with_prior_preservation: bool = False, |
| 555 | prior_loss_weight: float = 1.0, | 558 | prior_loss_weight: float = 1.0, |
| 559 | low_freq_noise: float = 0.05, | ||
| 556 | **kwargs, | 560 | **kwargs, |
| 557 | ): | 561 | ): |
| 558 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 562 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
| @@ -587,6 +591,7 @@ def train( | |||
| 587 | text_encoder, | 591 | text_encoder, |
| 588 | with_prior_preservation, | 592 | with_prior_preservation, |
| 589 | prior_loss_weight, | 593 | prior_loss_weight, |
| 594 | low_freq_noise, | ||
| 590 | seed, | 595 | seed, |
| 591 | ) | 596 | ) |
| 592 | 597 | ||
