diff options
-rw-r--r-- | training/functional.py | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py index 62b8260..a9c7a8a 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -256,6 +256,7 @@ def loss_step( | |||
256 | text_encoder: CLIPTextModel, | 256 | text_encoder: CLIPTextModel, |
257 | with_prior_preservation: bool, | 257 | with_prior_preservation: bool, |
258 | prior_loss_weight: float, | 258 | prior_loss_weight: float, |
259 | low_freq_noise: float, | ||
259 | seed: int, | 260 | seed: int, |
260 | step: int, | 261 | step: int, |
261 | batch: dict[str, Any], | 262 | batch: dict[str, Any], |
@@ -274,13 +275,15 @@ def loss_step( | |||
274 | layout=latents.layout, | 275 | layout=latents.layout, |
275 | device=latents.device, | 276 | device=latents.device, |
276 | generator=generator | 277 | generator=generator |
277 | ) + 0.1 * torch.randn( | ||
278 | latents.shape[0], latents.shape[1], 1, 1, | ||
279 | dtype=latents.dtype, | ||
280 | layout=latents.layout, | ||
281 | device=latents.device, | ||
282 | generator=generator | ||
283 | ) | 278 | ) |
279 | if low_freq_noise > 0: | ||
280 | noise += low_freq_noise * torch.randn( | ||
281 | latents.shape[0], latents.shape[1], 1, 1, | ||
282 | dtype=latents.dtype, | ||
283 | layout=latents.layout, | ||
284 | device=latents.device, | ||
285 | generator=generator | ||
286 | ) | ||
284 | bsz = latents.shape[0] | 287 | bsz = latents.shape[0] |
285 | # Sample a random timestep for each image | 288 | # Sample a random timestep for each image |
286 | timesteps = torch.randint( | 289 | timesteps = torch.randint( |
@@ -553,6 +556,7 @@ def train( | |||
553 | global_step_offset: int = 0, | 556 | global_step_offset: int = 0, |
554 | with_prior_preservation: bool = False, | 557 | with_prior_preservation: bool = False, |
555 | prior_loss_weight: float = 1.0, | 558 | prior_loss_weight: float = 1.0, |
559 | low_freq_noise: float = 0.05, | ||
556 | **kwargs, | 560 | **kwargs, |
557 | ): | 561 | ): |
558 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 562 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
@@ -587,6 +591,7 @@ def train( | |||
587 | text_encoder, | 591 | text_encoder, |
588 | with_prior_preservation, | 592 | with_prior_preservation, |
589 | prior_loss_weight, | 593 | prior_loss_weight, |
594 | low_freq_noise, | ||
590 | seed, | 595 | seed, |
591 | ) | 596 | ) |
592 | 597 | ||