From d9bb4a0d43276c8e120866af044fcf3566930859 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 23 Mar 2023 22:15:17 +0100 Subject: Bring back Perlin offset noise --- train_ti.py | 7 +++++++ training/functional.py | 15 ++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/train_ti.py b/train_ti.py index 9bc74c1..ef71f6f 100644 --- a/train_ti.py +++ b/train_ti.py @@ -187,6 +187,12 @@ def parse_args(): default="auto", help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', ) + parser.add_argument( + "--perlin_strength", + type=float, + default=0.1, + help="Perlin offset noise strength.", + ) parser.add_argument( "--num_train_epochs", type=int, @@ -655,6 +661,7 @@ def main(): checkpoint_frequency=args.checkpoint_frequency, milestone_checkpoints=not args.no_milestone_checkpoints, global_step_offset=global_step_offset, + perlin_strength=args.perlin_strength, # -- tokenizer=tokenizer, sample_scheduler=sample_scheduler, diff --git a/training/functional.py b/training/functional.py index 015fe5e..a5b339d 100644 --- a/training/functional.py +++ b/training/functional.py @@ -278,10 +278,11 @@ def loss_step( with_prior_preservation: bool, prior_loss_weight: float, seed: int, + perlin_strength: float, step: int, batch: dict[str, Any], eval: bool = False, - min_snr_gamma: int = 5 + min_snr_gamma: int = 5, ): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() @@ -300,6 +301,16 @@ def loss_step( generator=generator ) + if perlin_strength != 0: + noise += perlin_strength * perlin_noise( + latents.shape, + res=1, + octaves=4, + dtype=latents.dtype, + device=latents.device, + generator=generator + ) + # Sample a random timestep for each image timesteps = torch.randint( 0, @@ -600,6 +611,7 @@ def train( global_step_offset: int = 0, with_prior_preservation: bool = False, prior_loss_weight: float = 1.0, + perlin_strength: float = 0.1, **kwargs, ): text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( @@ -635,6 +647,7 @@ def train( with_prior_preservation, prior_loss_weight, seed, + perlin_strength, ) if accelerator.is_main_process: -- cgit v1.2.3-70-g09d2