diff options
author | Volpeon <git@volpeon.ink> | 2023-03-06 06:41:51 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-06 06:41:51 +0100 |
commit | a254c9f7bf3172aff8385174d761fa8bba508db0 (patch) | |
tree | ec9179f992fda32745f351a51a18e94122b34892 | |
parent | More flexible pipeline wrt init noise (diff) | |
download | textual-inversion-diff-a254c9f7bf3172aff8385174d761fa8bba508db0.tar.gz textual-inversion-diff-a254c9f7bf3172aff8385174d761fa8bba508db0.tar.bz2 textual-inversion-diff-a254c9f7bf3172aff8385174d761fa8bba508db0.zip |
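Update: perlin_noise() now takes a single shape tuple instead of separate batch_size/channels/width/height arguments; the pipeline's prepare_image() is adapted to the new signature and the pipeline's `image` argument now defaults to "noise"; the perlin_strength option and the extra Perlin noise term are removed from training (loss_step/train)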
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 14 | ||||
-rw-r--r-- | training/functional.py | 16 | ||||
-rw-r--r-- | util/noise.py | 8 |
3 files changed, 13 insertions, 25 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index f27be78..f426de1 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -307,10 +307,14 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
307 | return timesteps, num_inference_steps - t_start | 307 | return timesteps, num_inference_steps - t_start |
308 | 308 | ||
309 | def prepare_image(self, batch_size, width, height, dtype, device, generator=None): | 309 | def prepare_image(self, batch_size, width, height, dtype, device, generator=None): |
310 | noise = perlin_noise( | 310 | return (1.4 * perlin_noise( |
311 | batch_size, 1, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device | 311 | (batch_size, 1, width, height), |
312 | ).expand(batch_size, 3, width, height) | 312 | res=1, |
313 | return (1.4 * noise).clamp(-1, 1) | 313 | octaves=4, |
314 | generator=generator, | ||
315 | dtype=dtype, | ||
316 | device=device | ||
317 | )).clamp(-1, 1).expand(batch_size, 3, width, height) | ||
314 | 318 | ||
315 | def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): | 319 | def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): |
316 | init_image = init_image.to(device=device, dtype=dtype) | 320 | init_image = init_image.to(device=device, dtype=dtype) |
@@ -390,7 +394,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
390 | sag_scale: float = 0.75, | 394 | sag_scale: float = 0.75, |
391 | eta: float = 0.0, | 395 | eta: float = 0.0, |
392 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 396 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
393 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None, | 397 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = "noise", |
394 | output_type: str = "pil", | 398 | output_type: str = "pil", |
395 | return_dict: bool = True, | 399 | return_dict: bool = True, |
396 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 400 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
diff --git a/training/functional.py b/training/functional.py index db46766..27a43c2 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -254,7 +254,6 @@ def loss_step( | |||
254 | text_encoder: CLIPTextModel, | 254 | text_encoder: CLIPTextModel, |
255 | with_prior_preservation: bool, | 255 | with_prior_preservation: bool, |
256 | prior_loss_weight: float, | 256 | prior_loss_weight: float, |
257 | perlin_strength: float, | ||
258 | seed: int, | 257 | seed: int, |
259 | step: int, | 258 | step: int, |
260 | batch: dict[str, Any], | 259 | batch: dict[str, Any], |
@@ -277,19 +276,6 @@ def loss_step( | |||
277 | generator=generator | 276 | generator=generator |
278 | ) | 277 | ) |
279 | 278 | ||
280 | if perlin_strength != 0: | ||
281 | noise += perlin_strength * perlin_noise( | ||
282 | latents.shape[0], | ||
283 | latents.shape[1], | ||
284 | latents.shape[2], | ||
285 | latents.shape[3], | ||
286 | res=1, | ||
287 | octaves=4, | ||
288 | dtype=latents.dtype, | ||
289 | device=latents.device, | ||
290 | generator=generator | ||
291 | ) | ||
292 | |||
293 | # Sample a random timestep for each image | 279 | # Sample a random timestep for each image |
294 | timesteps = torch.randint( | 280 | timesteps = torch.randint( |
295 | 0, | 281 | 0, |
@@ -574,7 +560,6 @@ def train( | |||
574 | global_step_offset: int = 0, | 560 | global_step_offset: int = 0, |
575 | with_prior_preservation: bool = False, | 561 | with_prior_preservation: bool = False, |
576 | prior_loss_weight: float = 1.0, | 562 | prior_loss_weight: float = 1.0, |
577 | perlin_strength: float = 0.1, | ||
578 | **kwargs, | 563 | **kwargs, |
579 | ): | 564 | ): |
580 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 565 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
@@ -609,7 +594,6 @@ def train( | |||
609 | text_encoder, | 594 | text_encoder, |
610 | with_prior_preservation, | 595 | with_prior_preservation, |
611 | prior_loss_weight, | 596 | prior_loss_weight, |
612 | perlin_strength, | ||
613 | seed, | 597 | seed, |
614 | ) | 598 | ) |
615 | 599 | ||
diff --git a/util/noise.py b/util/noise.py index 3c4f82d..e3ebdb2 100644 --- a/util/noise.py +++ b/util/noise.py | |||
@@ -48,13 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, d | |||
48 | return noise | 48 | return noise |
49 | 49 | ||
50 | 50 | ||
51 | def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): | 51 | def perlin_noise(shape: tuple[int, int, int, int], res=8, octaves=1, dtype=None, device=None, generator=None): |
52 | return torch.stack([ | 52 | return torch.stack([ |
53 | torch.stack([ | 53 | torch.stack([ |
54 | rand_perlin_2d_octaves( | 54 | rand_perlin_2d_octaves( |
55 | (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator | 55 | (shape[2], shape[3]), (res, res), octaves, dtype=dtype, device=device, generator=generator |
56 | ) | 56 | ) |
57 | for _ in range(channels) | 57 | for _ in range(shape[1]) |
58 | ]) | 58 | ]) |
59 | for _ in range(batch_size) | 59 | for _ in range(shape[0]) |
60 | ]) | 60 | ]) |