diff options
author | Volpeon <git@volpeon.ink> | 2023-03-24 17:23:09 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-24 17:23:09 +0100 |
commit | 9bfb4a078f63a7ce6e35e89093f17febd9ff4b51 (patch) | |
tree | 41b83780c79803531c7208a72bff9206ffa908da /training | |
parent | Fixed Lora training perf issue (diff) | |
download | textual-inversion-diff-9bfb4a078f63a7ce6e35e89093f17febd9ff4b51.tar.gz textual-inversion-diff-9bfb4a078f63a7ce6e35e89093f17febd9ff4b51.tar.bz2 textual-inversion-diff-9bfb4a078f63a7ce6e35e89093f17febd9ff4b51.zip |
Update
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py index ee73ab2..87bb339 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -277,7 +277,7 @@ def loss_step( | |||
277 | with_prior_preservation: bool, | 277 | with_prior_preservation: bool, |
278 | prior_loss_weight: float, | 278 | prior_loss_weight: float, |
279 | seed: int, | 279 | seed: int, |
280 | perlin_strength: float, | 280 | offset_noise_strength: float, |
281 | step: int, | 281 | step: int, |
282 | batch: dict[str, Any], | 282 | batch: dict[str, Any], |
283 | eval: bool = False, | 283 | eval: bool = False, |
@@ -300,11 +300,10 @@ def loss_step( | |||
300 | generator=generator | 300 | generator=generator |
301 | ) | 301 | ) |
302 | 302 | ||
303 | if perlin_strength != 0: | 303 | if offset_noise_strength != 0: |
304 | noise += perlin_strength * perlin_noise( | 304 | noise += offset_noise_strength * perlin_noise( |
305 | latents.shape, | 305 | latents.shape, |
306 | res=1, | 306 | res=1, |
307 | octaves=4, | ||
308 | dtype=latents.dtype, | 307 | dtype=latents.dtype, |
309 | device=latents.device, | 308 | device=latents.device, |
310 | generator=generator | 309 | generator=generator |
@@ -610,7 +609,7 @@ def train( | |||
610 | global_step_offset: int = 0, | 609 | global_step_offset: int = 0, |
611 | with_prior_preservation: bool = False, | 610 | with_prior_preservation: bool = False, |
612 | prior_loss_weight: float = 1.0, | 611 | prior_loss_weight: float = 1.0, |
613 | perlin_strength: float = 0.1, | 612 | offset_noise_strength: float = 0.1, |
614 | **kwargs, | 613 | **kwargs, |
615 | ): | 614 | ): |
616 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 615 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
@@ -642,7 +641,7 @@ def train( | |||
642 | with_prior_preservation, | 641 | with_prior_preservation, |
643 | prior_loss_weight, | 642 | prior_loss_weight, |
644 | seed, | 643 | seed, |
645 | perlin_strength, | 644 | offset_noise_strength, |
646 | ) | 645 | ) |
647 | 646 | ||
648 | if accelerator.is_main_process: | 647 | if accelerator.is_main_process: |