summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-23 22:15:17 +0100
committerVolpeon <git@volpeon.ink>2023-03-23 22:15:17 +0100
commitd9bb4a0d43276c8e120866af044fcf3566930859 (patch)
tree6b0abba5270d02087d3f49d92201b90133882330 /training
parentUpdate (diff)
downloadtextual-inversion-diff-d9bb4a0d43276c8e120866af044fcf3566930859.tar.gz
textual-inversion-diff-d9bb4a0d43276c8e120866af044fcf3566930859.tar.bz2
textual-inversion-diff-d9bb4a0d43276c8e120866af044fcf3566930859.zip
Bring back Perlin offset noise
Diffstat (limited to 'training')
-rw-r--r--training/functional.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py
index 015fe5e..a5b339d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -278,10 +278,11 @@ def loss_step(
278 with_prior_preservation: bool, 278 with_prior_preservation: bool,
279 prior_loss_weight: float, 279 prior_loss_weight: float,
280 seed: int, 280 seed: int,
281 perlin_strength: float,
281 step: int, 282 step: int,
282 batch: dict[str, Any], 283 batch: dict[str, Any],
283 eval: bool = False, 284 eval: bool = False,
284 min_snr_gamma: int = 5 285 min_snr_gamma: int = 5,
285): 286):
286 # Convert images to latent space 287 # Convert images to latent space
287 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 288 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
@@ -300,6 +301,16 @@ def loss_step(
300 generator=generator 301 generator=generator
301 ) 302 )
302 303
304 if perlin_strength != 0:
305 noise += perlin_strength * perlin_noise(
306 latents.shape,
307 res=1,
308 octaves=4,
309 dtype=latents.dtype,
310 device=latents.device,
311 generator=generator
312 )
313
303 # Sample a random timestep for each image 314 # Sample a random timestep for each image
304 timesteps = torch.randint( 315 timesteps = torch.randint(
305 0, 316 0,
@@ -600,6 +611,7 @@ def train(
600 global_step_offset: int = 0, 611 global_step_offset: int = 0,
601 with_prior_preservation: bool = False, 612 with_prior_preservation: bool = False,
602 prior_loss_weight: float = 1.0, 613 prior_loss_weight: float = 1.0,
614 perlin_strength: float = 0.1,
603 **kwargs, 615 **kwargs,
604): 616):
605 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 617 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -635,6 +647,7 @@ def train(
635 with_prior_preservation, 647 with_prior_preservation,
636 prior_loss_weight, 648 prior_loss_weight,
637 seed, 649 seed,
650 perlin_strength,
638 ) 651 )
639 652
640 if accelerator.is_main_process: 653 if accelerator.is_main_process: