author     Volpeon <git@volpeon.ink>  2023-02-14 11:02:41 +0100
committer  Volpeon <git@volpeon.ink>  2023-02-14 11:02:41 +0100
commit     7c56bf2fa99a8f6b7e23bb66ef4a2364dec3fbbd (patch)
tree       03e2f9517e3b05be76d3c04ef6165b23c47195f8 /training
parent     Better noise generation during training: https://www.crosslabs.org/blog/diffu... (diff)
Made low-freq noise configurable
Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  17
1 file changed, 11 insertions(+), 6 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 62b8260..a9c7a8a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -256,6 +256,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
+    low_freq_noise: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -274,13 +275,15 @@ def loss_step(
         layout=latents.layout,
         device=latents.device,
         generator=generator
-    ) + 0.1 * torch.randn(
-        latents.shape[0], latents.shape[1], 1, 1,
-        dtype=latents.dtype,
-        layout=latents.layout,
-        device=latents.device,
-        generator=generator
     )
+    if low_freq_noise > 0:
+        noise += low_freq_noise * torch.randn(
+            latents.shape[0], latents.shape[1], 1, 1,
+            dtype=latents.dtype,
+            layout=latents.layout,
+            device=latents.device,
+            generator=generator
+        )
     bsz = latents.shape[0]
     # Sample a random timestep for each image
     timesteps = torch.randint(
@@ -553,6 +556,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
+    low_freq_noise: float = 0.05,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -587,6 +591,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
+        low_freq_noise,
         seed,
     )
 
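
Note: as a reference for readers of this commit, the sketch below restates the added branch as a standalone helper. It is a minimal, hypothetical illustration (the function add_low_freq_noise does not exist in the repository) of the offset / low-frequency noise trick this commit makes configurable: one random value is drawn per sample and channel, broadcast over the latent's spatial dimensions, and scaled by low_freq_noise, replacing the previously hard-coded factor of 0.1.

import torch

def add_low_freq_noise(
    noise: torch.Tensor,
    latents: torch.Tensor,
    low_freq_noise: float,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    # Hypothetical helper mirroring the added branch in loss_step():
    # draw one value per (sample, channel) and broadcast it over H and W,
    # giving the per-pixel noise a spatially constant ("low frequency") offset.
    if low_freq_noise > 0:
        noise = noise + low_freq_noise * torch.randn(
            latents.shape[0], latents.shape[1], 1, 1,
            dtype=latents.dtype,
            layout=latents.layout,
            device=latents.device,
            generator=generator,
        )
    return noise

With the new train() default of low_freq_noise = 0.05, every noise sample receives a small per-channel DC offset; passing low_freq_noise = 0 disables the extra term entirely.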