summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py14
-rw-r--r--training/functional.py16
-rw-r--r--util/noise.py8
3 files changed, 13 insertions, 25 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f27be78..f426de1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -307,10 +307,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
307 return timesteps, num_inference_steps - t_start 307 return timesteps, num_inference_steps - t_start
308 308
309 def prepare_image(self, batch_size, width, height, dtype, device, generator=None): 309 def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
310 noise = perlin_noise( 310 return (1.4 * perlin_noise(
311 batch_size, 1, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device 311 (batch_size, 1, width, height),
312 ).expand(batch_size, 3, width, height) 312 res=1,
313 return (1.4 * noise).clamp(-1, 1) 313 octaves=4,
314 generator=generator,
315 dtype=dtype,
316 device=device
317 )).clamp(-1, 1).expand(batch_size, 3, width, height)
314 318
315 def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): 319 def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
316 init_image = init_image.to(device=device, dtype=dtype) 320 init_image = init_image.to(device=device, dtype=dtype)
@@ -390,7 +394,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
390 sag_scale: float = 0.75, 394 sag_scale: float = 0.75,
391 eta: float = 0.0, 395 eta: float = 0.0,
392 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 396 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
393 image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, 397 image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = "noise",
394 output_type: str = "pil", 398 output_type: str = "pil",
395 return_dict: bool = True, 399 return_dict: bool = True,
396 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 400 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
diff --git a/training/functional.py b/training/functional.py
index db46766..27a43c2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -254,7 +254,6 @@ def loss_step(
254 text_encoder: CLIPTextModel, 254 text_encoder: CLIPTextModel,
255 with_prior_preservation: bool, 255 with_prior_preservation: bool,
256 prior_loss_weight: float, 256 prior_loss_weight: float,
257 perlin_strength: float,
258 seed: int, 257 seed: int,
259 step: int, 258 step: int,
260 batch: dict[str, Any], 259 batch: dict[str, Any],
@@ -277,19 +276,6 @@ def loss_step(
277 generator=generator 276 generator=generator
278 ) 277 )
279 278
280 if perlin_strength != 0:
281 noise += perlin_strength * perlin_noise(
282 latents.shape[0],
283 latents.shape[1],
284 latents.shape[2],
285 latents.shape[3],
286 res=1,
287 octaves=4,
288 dtype=latents.dtype,
289 device=latents.device,
290 generator=generator
291 )
292
293 # Sample a random timestep for each image 279 # Sample a random timestep for each image
294 timesteps = torch.randint( 280 timesteps = torch.randint(
295 0, 281 0,
@@ -574,7 +560,6 @@ def train(
574 global_step_offset: int = 0, 560 global_step_offset: int = 0,
575 with_prior_preservation: bool = False, 561 with_prior_preservation: bool = False,
576 prior_loss_weight: float = 1.0, 562 prior_loss_weight: float = 1.0,
577 perlin_strength: float = 0.1,
578 **kwargs, 563 **kwargs,
579): 564):
580 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 565 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -609,7 +594,6 @@ def train(
609 text_encoder, 594 text_encoder,
610 with_prior_preservation, 595 with_prior_preservation,
611 prior_loss_weight, 596 prior_loss_weight,
612 perlin_strength,
613 seed, 597 seed,
614 ) 598 )
615 599
diff --git a/util/noise.py b/util/noise.py
index 3c4f82d..e3ebdb2 100644
--- a/util/noise.py
+++ b/util/noise.py
@@ -48,13 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, d
48 return noise 48 return noise
49 49
50 50
51def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): 51def perlin_noise(shape: tuple[int, int, int, int], res=8, octaves=1, dtype=None, device=None, generator=None):
52 return torch.stack([ 52 return torch.stack([
53 torch.stack([ 53 torch.stack([
54 rand_perlin_2d_octaves( 54 rand_perlin_2d_octaves(
55 (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator 55 (shape[2], shape[3]), (res, res), octaves, dtype=dtype, device=device, generator=generator
56 ) 56 )
57 for _ in range(channels) 57 for _ in range(shape[1])
58 ]) 58 ])
59 for _ in range(batch_size) 59 for _ in range(shape[0])
60 ]) 60 ])