summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py
index ee73ab2..87bb339 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -277,7 +277,7 @@ def loss_step(
277 with_prior_preservation: bool, 277 with_prior_preservation: bool,
278 prior_loss_weight: float, 278 prior_loss_weight: float,
279 seed: int, 279 seed: int,
280 perlin_strength: float, 280 offset_noise_strength: float,
281 step: int, 281 step: int,
282 batch: dict[str, Any], 282 batch: dict[str, Any],
283 eval: bool = False, 283 eval: bool = False,
@@ -300,11 +300,10 @@ def loss_step(
300 generator=generator 300 generator=generator
301 ) 301 )
302 302
303 if perlin_strength != 0: 303 if offset_noise_strength != 0:
304 noise += perlin_strength * perlin_noise( 304 noise += offset_noise_strength * perlin_noise(
305 latents.shape, 305 latents.shape,
306 res=1, 306 res=1,
307 octaves=4,
308 dtype=latents.dtype, 307 dtype=latents.dtype,
309 device=latents.device, 308 device=latents.device,
310 generator=generator 309 generator=generator
@@ -610,7 +609,7 @@ def train(
610 global_step_offset: int = 0, 609 global_step_offset: int = 0,
611 with_prior_preservation: bool = False, 610 with_prior_preservation: bool = False,
612 prior_loss_weight: float = 1.0, 611 prior_loss_weight: float = 1.0,
613 perlin_strength: float = 0.1, 612 offset_noise_strength: float = 0.1,
614 **kwargs, 613 **kwargs,
615): 614):
616 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 615 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -642,7 +641,7 @@ def train(
642 with_prior_preservation, 641 with_prior_preservation,
643 prior_loss_weight, 642 prior_loss_weight,
644 seed, 643 seed,
645 perlin_strength, 644 offset_noise_strength,
646 ) 645 )
647 646
648 if accelerator.is_main_process: 647 if accelerator.is_main_process: