summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- train_ti.py 7
-rw-r--r-- training/functional.py 15
2 files changed, 21 insertions, 1 deletion
diff --git a/train_ti.py b/train_ti.py
index 9bc74c1..ef71f6f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -188,6 +188,12 @@ def parse_args():
188 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', 188 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
189 ) 189 )
190 parser.add_argument( 190 parser.add_argument(
191 "--perlin_strength",
192 type=float,
193 default=0.1,
194 help="Perlin offset noise strength.",
195 )
196 parser.add_argument(
191 "--num_train_epochs", 197 "--num_train_epochs",
192 type=int, 198 type=int,
193 default=100 199 default=100
@@ -655,6 +661,7 @@ def main():
655 checkpoint_frequency=args.checkpoint_frequency, 661 checkpoint_frequency=args.checkpoint_frequency,
656 milestone_checkpoints=not args.no_milestone_checkpoints, 662 milestone_checkpoints=not args.no_milestone_checkpoints,
657 global_step_offset=global_step_offset, 663 global_step_offset=global_step_offset,
664 perlin_strength=args.perlin_strength,
658 # -- 665 # --
659 tokenizer=tokenizer, 666 tokenizer=tokenizer,
660 sample_scheduler=sample_scheduler, 667 sample_scheduler=sample_scheduler,
diff --git a/training/functional.py b/training/functional.py
index 015fe5e..a5b339d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -278,10 +278,11 @@ def loss_step(
278 with_prior_preservation: bool, 278 with_prior_preservation: bool,
279 prior_loss_weight: float, 279 prior_loss_weight: float,
280 seed: int, 280 seed: int,
281 perlin_strength: float,
281 step: int, 282 step: int,
282 batch: dict[str, Any], 283 batch: dict[str, Any],
283 eval: bool = False, 284 eval: bool = False,
284 min_snr_gamma: int = 5 285 min_snr_gamma: int = 5,
285): 286):
286 # Convert images to latent space 287 # Convert images to latent space
287 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 288 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
@@ -300,6 +301,16 @@ def loss_step(
300 generator=generator 301 generator=generator
301 ) 302 )
302 303
304 if perlin_strength != 0:
305 noise += perlin_strength * perlin_noise(
306 latents.shape,
307 res=1,
308 octaves=4,
309 dtype=latents.dtype,
310 device=latents.device,
311 generator=generator
312 )
313
303 # Sample a random timestep for each image 314 # Sample a random timestep for each image
304 timesteps = torch.randint( 315 timesteps = torch.randint(
305 0, 316 0,
@@ -600,6 +611,7 @@ def train(
600 global_step_offset: int = 0, 611 global_step_offset: int = 0,
601 with_prior_preservation: bool = False, 612 with_prior_preservation: bool = False,
602 prior_loss_weight: float = 1.0, 613 prior_loss_weight: float = 1.0,
614 perlin_strength: float = 0.1,
603 **kwargs, 615 **kwargs,
604): 616):
605 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 617 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -635,6 +647,7 @@ def train(
635 with_prior_preservation, 647 with_prior_preservation,
636 prior_loss_weight, 648 prior_loss_weight,
637 seed, 649 seed,
650 perlin_strength,
638 ) 651 )
639 652
640 if accelerator.is_main_process: 653 if accelerator.is_main_process: