diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-14 11:02:41 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-14 11:02:41 +0100 | 
| commit | 7c56bf2fa99a8f6b7e23bb66ef4a2364dec3fbbd (patch) | |
| tree | 03e2f9517e3b05be76d3c04ef6165b23c47195f8 /training | |
| parent | Better noise generation during training: https://www.crosslabs.org/blog/diffu... (diff) | |
| download | textual-inversion-diff-7c56bf2fa99a8f6b7e23bb66ef4a2364dec3fbbd.tar.gz textual-inversion-diff-7c56bf2fa99a8f6b7e23bb66ef4a2364dec3fbbd.tar.bz2 textual-inversion-diff-7c56bf2fa99a8f6b7e23bb66ef4a2364dec3fbbd.zip | |
Made low-freq noise configurable
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 17 | 
1 files changed, 11 insertions, 6 deletions
| diff --git a/training/functional.py b/training/functional.py index 62b8260..a9c7a8a 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -256,6 +256,7 @@ def loss_step( | |||
| 256 | text_encoder: CLIPTextModel, | 256 | text_encoder: CLIPTextModel, | 
| 257 | with_prior_preservation: bool, | 257 | with_prior_preservation: bool, | 
| 258 | prior_loss_weight: float, | 258 | prior_loss_weight: float, | 
| 259 | low_freq_noise: float, | ||
| 259 | seed: int, | 260 | seed: int, | 
| 260 | step: int, | 261 | step: int, | 
| 261 | batch: dict[str, Any], | 262 | batch: dict[str, Any], | 
| @@ -274,13 +275,15 @@ def loss_step( | |||
| 274 | layout=latents.layout, | 275 | layout=latents.layout, | 
| 275 | device=latents.device, | 276 | device=latents.device, | 
| 276 | generator=generator | 277 | generator=generator | 
| 277 | ) + 0.1 * torch.randn( | ||
| 278 | latents.shape[0], latents.shape[1], 1, 1, | ||
| 279 | dtype=latents.dtype, | ||
| 280 | layout=latents.layout, | ||
| 281 | device=latents.device, | ||
| 282 | generator=generator | ||
| 283 | ) | 278 | ) | 
| 279 | if low_freq_noise > 0: | ||
| 280 | noise += low_freq_noise * torch.randn( | ||
| 281 | latents.shape[0], latents.shape[1], 1, 1, | ||
| 282 | dtype=latents.dtype, | ||
| 283 | layout=latents.layout, | ||
| 284 | device=latents.device, | ||
| 285 | generator=generator | ||
| 286 | ) | ||
| 284 | bsz = latents.shape[0] | 287 | bsz = latents.shape[0] | 
| 285 | # Sample a random timestep for each image | 288 | # Sample a random timestep for each image | 
| 286 | timesteps = torch.randint( | 289 | timesteps = torch.randint( | 
| @@ -553,6 +556,7 @@ def train( | |||
| 553 | global_step_offset: int = 0, | 556 | global_step_offset: int = 0, | 
| 554 | with_prior_preservation: bool = False, | 557 | with_prior_preservation: bool = False, | 
| 555 | prior_loss_weight: float = 1.0, | 558 | prior_loss_weight: float = 1.0, | 
| 559 | low_freq_noise: float = 0.05, | ||
| 556 | **kwargs, | 560 | **kwargs, | 
| 557 | ): | 561 | ): | 
| 558 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 562 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 
| @@ -587,6 +591,7 @@ def train( | |||
| 587 | text_encoder, | 591 | text_encoder, | 
| 588 | with_prior_preservation, | 592 | with_prior_preservation, | 
| 589 | prior_loss_weight, | 593 | prior_loss_weight, | 
| 594 | low_freq_noise, | ||
| 590 | seed, | 595 | seed, | 
| 591 | ) | 596 | ) | 
| 592 | 597 | ||
