diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 16 |
1 files changed, 0 insertions, 16 deletions
diff --git a/training/functional.py b/training/functional.py index db46766..27a43c2 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -254,7 +254,6 @@ def loss_step( | |||
| 254 | text_encoder: CLIPTextModel, | 254 | text_encoder: CLIPTextModel, |
| 255 | with_prior_preservation: bool, | 255 | with_prior_preservation: bool, |
| 256 | prior_loss_weight: float, | 256 | prior_loss_weight: float, |
| 257 | perlin_strength: float, | ||
| 258 | seed: int, | 257 | seed: int, |
| 259 | step: int, | 258 | step: int, |
| 260 | batch: dict[str, Any], | 259 | batch: dict[str, Any], |
| @@ -277,19 +276,6 @@ def loss_step( | |||
| 277 | generator=generator | 276 | generator=generator |
| 278 | ) | 277 | ) |
| 279 | 278 | ||
| 280 | if perlin_strength != 0: | ||
| 281 | noise += perlin_strength * perlin_noise( | ||
| 282 | latents.shape[0], | ||
| 283 | latents.shape[1], | ||
| 284 | latents.shape[2], | ||
| 285 | latents.shape[3], | ||
| 286 | res=1, | ||
| 287 | octaves=4, | ||
| 288 | dtype=latents.dtype, | ||
| 289 | device=latents.device, | ||
| 290 | generator=generator | ||
| 291 | ) | ||
| 292 | |||
| 293 | # Sample a random timestep for each image | 279 | # Sample a random timestep for each image |
| 294 | timesteps = torch.randint( | 280 | timesteps = torch.randint( |
| 295 | 0, | 281 | 0, |
| @@ -574,7 +560,6 @@ def train( | |||
| 574 | global_step_offset: int = 0, | 560 | global_step_offset: int = 0, |
| 575 | with_prior_preservation: bool = False, | 561 | with_prior_preservation: bool = False, |
| 576 | prior_loss_weight: float = 1.0, | 562 | prior_loss_weight: float = 1.0, |
| 577 | perlin_strength: float = 0.1, | ||
| 578 | **kwargs, | 563 | **kwargs, |
| 579 | ): | 564 | ): |
| 580 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 565 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
| @@ -609,7 +594,6 @@ def train( | |||
| 609 | text_encoder, | 594 | text_encoder, |
| 610 | with_prior_preservation, | 595 | with_prior_preservation, |
| 611 | prior_loss_weight, | 596 | prior_loss_weight, |
| 612 | perlin_strength, | ||
| 613 | seed, | 597 | seed, |
| 614 | ) | 598 | ) |
| 615 | 599 | ||
