From a254c9f7bf3172aff8385174d761fa8bba508db0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 6 Mar 2023 06:41:51 +0100
Subject: Update

---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 14 +++++++++-----
 training/functional.py                              | 16 ----------------
 util/noise.py                                       |  8 ++++----
 3 files changed, 13 insertions(+), 25 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f27be78..f426de1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -307,10 +307,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return timesteps, num_inference_steps - t_start

     def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
-        noise = perlin_noise(
-            batch_size, 1, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device
-        ).expand(batch_size, 3, width, height)
-        return (1.4 * noise).clamp(-1, 1)
+        return (1.4 * perlin_noise(
+            (batch_size, 1, width, height),
+            res=1,
+            octaves=4,
+            generator=generator,
+            dtype=dtype,
+            device=device
+        )).clamp(-1, 1).expand(batch_size, 3, width, height)

     def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -390,7 +394,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = "noise",
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
diff --git a/training/functional.py b/training/functional.py
index db46766..27a43c2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -254,7 +254,6 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
-    perlin_strength: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -277,19 +276,6 @@ def loss_step(
         generator=generator
     )

-    if perlin_strength != 0:
-        noise += perlin_strength * perlin_noise(
-            latents.shape[0],
-            latents.shape[1],
-            latents.shape[2],
-            latents.shape[3],
-            res=1,
-            octaves=4,
-            dtype=latents.dtype,
-            device=latents.device,
-            generator=generator
-        )
-
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -574,7 +560,6 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
-    perlin_strength: float = 0.1,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -609,7 +594,6 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
-        perlin_strength,
         seed,
     )

diff --git a/util/noise.py b/util/noise.py
index 3c4f82d..e3ebdb2 100644
--- a/util/noise.py
+++ b/util/noise.py
@@ -48,13 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, d
     return noise


-def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+def perlin_noise(shape: tuple[int, int, int, int], res=8, octaves=1, dtype=None, device=None, generator=None):
     return torch.stack([
         torch.stack([
             rand_perlin_2d_octaves(
-                (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+                (shape[2], shape[3]), (res, res), octaves, dtype=dtype, device=device, generator=generator
             )
-            for _ in range(channels)
+            for _ in range(shape[1])
         ])
-        for _ in range(batch_size)
+        for _ in range(shape[0])
     ])
-- 
cgit v1.2.3-54-g00ecf
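
Usage note: after this patch, perlin_noise takes a single (batch, channels, width, height) shape tuple instead of four separate dimension arguments. A minimal sketch of the updated call, mirroring the new prepare_image above; the concrete sizes and seed here are illustrative only, not part of the patch:

    import torch
    from util.noise import perlin_noise

    # New signature: one shape tuple instead of batch_size/channels/width/height.
    noise = perlin_noise(
        (2, 1, 512, 512),  # (batch, channels, width, height)
        res=1,
        octaves=4,
        generator=torch.Generator().manual_seed(0),
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    # As in prepare_image: scale, clamp to the usual [-1, 1] image range,
    # then broadcast the single channel to RGB.
    image = (1.4 * noise).clamp(-1, 1).expand(2, 3, 512, 512)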