diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/training/functional.py b/training/functional.py index ee73ab2..87bb339 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -277,7 +277,7 @@ def loss_step( | |||
| 277 | with_prior_preservation: bool, | 277 | with_prior_preservation: bool, |
| 278 | prior_loss_weight: float, | 278 | prior_loss_weight: float, |
| 279 | seed: int, | 279 | seed: int, |
| 280 | perlin_strength: float, | 280 | offset_noise_strength: float, |
| 281 | step: int, | 281 | step: int, |
| 282 | batch: dict[str, Any], | 282 | batch: dict[str, Any], |
| 283 | eval: bool = False, | 283 | eval: bool = False, |
| @@ -300,11 +300,10 @@ def loss_step( | |||
| 300 | generator=generator | 300 | generator=generator |
| 301 | ) | 301 | ) |
| 302 | 302 | ||
| 303 | if perlin_strength != 0: | 303 | if offset_noise_strength != 0: |
| 304 | noise += perlin_strength * perlin_noise( | 304 | noise += offset_noise_strength * perlin_noise( |
| 305 | latents.shape, | 305 | latents.shape, |
| 306 | res=1, | 306 | res=1, |
| 307 | octaves=4, | ||
| 308 | dtype=latents.dtype, | 307 | dtype=latents.dtype, |
| 309 | device=latents.device, | 308 | device=latents.device, |
| 310 | generator=generator | 309 | generator=generator |
| @@ -610,7 +609,7 @@ def train( | |||
| 610 | global_step_offset: int = 0, | 609 | global_step_offset: int = 0, |
| 611 | with_prior_preservation: bool = False, | 610 | with_prior_preservation: bool = False, |
| 612 | prior_loss_weight: float = 1.0, | 611 | prior_loss_weight: float = 1.0, |
| 613 | perlin_strength: float = 0.1, | 612 | offset_noise_strength: float = 0.1, |
| 614 | **kwargs, | 613 | **kwargs, |
| 615 | ): | 614 | ): |
| 616 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 615 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
| @@ -642,7 +641,7 @@ def train( | |||
| 642 | with_prior_preservation, | 641 | with_prior_preservation, |
| 643 | prior_loss_weight, | 642 | prior_loss_weight, |
| 644 | seed, | 643 | seed, |
| 645 | perlin_strength, | 644 | offset_noise_strength, |
| 646 | ) | 645 | ) |
| 647 | 646 | ||
| 648 | if accelerator.is_main_process: | 647 | if accelerator.is_main_process: |
