diff options
| -rw-r--r-- | train_ti.py | 7 | ||||
| -rw-r--r-- | training/functional.py | 15 |
2 files changed, 21 insertions, 1 deletion
diff --git a/train_ti.py b/train_ti.py index 9bc74c1..ef71f6f 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -188,6 +188,12 @@ def parse_args(): | |||
| 188 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | 188 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', |
| 189 | ) | 189 | ) |
| 190 | parser.add_argument( | 190 | parser.add_argument( |
| 191 | "--perlin_strength", | ||
| 192 | type=float, | ||
| 193 | default=0.1, | ||
| 194 | help="Perlin offset noise strength.", | ||
| 195 | ) | ||
| 196 | parser.add_argument( | ||
| 191 | "--num_train_epochs", | 197 | "--num_train_epochs", |
| 192 | type=int, | 198 | type=int, |
| 193 | default=100 | 199 | default=100 |
| @@ -655,6 +661,7 @@ def main(): | |||
| 655 | checkpoint_frequency=args.checkpoint_frequency, | 661 | checkpoint_frequency=args.checkpoint_frequency, |
| 656 | milestone_checkpoints=not args.no_milestone_checkpoints, | 662 | milestone_checkpoints=not args.no_milestone_checkpoints, |
| 657 | global_step_offset=global_step_offset, | 663 | global_step_offset=global_step_offset, |
| 664 | perlin_strength=args.perlin_strength, | ||
| 658 | # -- | 665 | # -- |
| 659 | tokenizer=tokenizer, | 666 | tokenizer=tokenizer, |
| 660 | sample_scheduler=sample_scheduler, | 667 | sample_scheduler=sample_scheduler, |
diff --git a/training/functional.py b/training/functional.py index 015fe5e..a5b339d 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -278,10 +278,11 @@ def loss_step( | |||
| 278 | with_prior_preservation: bool, | 278 | with_prior_preservation: bool, |
| 279 | prior_loss_weight: float, | 279 | prior_loss_weight: float, |
| 280 | seed: int, | 280 | seed: int, |
| 281 | perlin_strength: float, | ||
| 281 | step: int, | 282 | step: int, |
| 282 | batch: dict[str, Any], | 283 | batch: dict[str, Any], |
| 283 | eval: bool = False, | 284 | eval: bool = False, |
| 284 | min_snr_gamma: int = 5 | 285 | min_snr_gamma: int = 5, |
| 285 | ): | 286 | ): |
| 286 | # Convert images to latent space | 287 | # Convert images to latent space |
| 287 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 288 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| @@ -300,6 +301,16 @@ def loss_step( | |||
| 300 | generator=generator | 301 | generator=generator |
| 301 | ) | 302 | ) |
| 302 | 303 | ||
| 304 | if perlin_strength != 0: | ||
| 305 | noise += perlin_strength * perlin_noise( | ||
| 306 | latents.shape, | ||
| 307 | res=1, | ||
| 308 | octaves=4, | ||
| 309 | dtype=latents.dtype, | ||
| 310 | device=latents.device, | ||
| 311 | generator=generator | ||
| 312 | ) | ||
| 313 | |||
| 303 | # Sample a random timestep for each image | 314 | # Sample a random timestep for each image |
| 304 | timesteps = torch.randint( | 315 | timesteps = torch.randint( |
| 305 | 0, | 316 | 0, |
| @@ -600,6 +611,7 @@ def train( | |||
| 600 | global_step_offset: int = 0, | 611 | global_step_offset: int = 0, |
| 601 | with_prior_preservation: bool = False, | 612 | with_prior_preservation: bool = False, |
| 602 | prior_loss_weight: float = 1.0, | 613 | prior_loss_weight: float = 1.0, |
| 614 | perlin_strength: float = 0.1, | ||
| 603 | **kwargs, | 615 | **kwargs, |
| 604 | ): | 616 | ): |
| 605 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 617 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
| @@ -635,6 +647,7 @@ def train( | |||
| 635 | with_prior_preservation, | 647 | with_prior_preservation, |
| 636 | prior_loss_weight, | 648 | prior_loss_weight, |
| 637 | seed, | 649 | seed, |
| 650 | perlin_strength, | ||
| 638 | ) | 651 | ) |
| 639 | 652 | ||
| 640 | if accelerator.is_main_process: | 653 | if accelerator.is_main_process: |
