author     Volpeon <git@volpeon.ink>  2023-03-04 09:46:41 +0100
committer  Volpeon <git@volpeon.ink>  2023-03-04 09:46:41 +0100
commit     73910b7f55244ce787fc6a3e6af09240ef0cdfd3
tree       3ef927578fc54b59ab6ff1bd00c3f804c0b9a7bf
parent     Pipeline: Improved initial image generation
Pipeline: Perlin noise for init image
-rw-r--r--  infer.py                                              2
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  23
-rw-r--r--  train_dreambooth.py                                   2
-rw-r--r--  train_lora.py                                         2
-rw-r--r--  train_ti.py                                           2
-rw-r--r--  util/files.py (renamed from util.py)                  0
-rw-r--r--  util/noise.py                                        57
7 files changed, 70 insertions, 18 deletions
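Summary of the change: the hard-coded init image (a 2x2 Gaussian sample upscaled bicubically) is replaced by multi-octave Perlin noise, and the top-level util.py module becomes a util package. The resulting layout, as reflected in the file list above:

util/
├── files.py   # load_config, load_embeddings_from_dir (moved unchanged from util.py)
└── noise.py   # new: rand_perlin_2d, rand_perlin_2d_octaves, perlin_noise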
diff --git a/infer.py b/infer.py
index 07dcd22..cf59bba 100644
--- a/infer.py
+++ b/infer.py
@@ -32,7 +32,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 2251848..a6b31d8 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -24,7 +24,9 @@ from diffusers import (
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
+
 from models.clip.util import unify_input_ids, get_extended_embeddings
+from util.noise import perlin_noise
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -304,23 +306,18 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None):
-        offset = (max_offset * (2 * torch.rand(
+    def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
+        max = 0.4
+        offset = max * (2 * torch.rand(
             (batch_size, 1, 1, 1),
             dtype=dtype,
             device=device,
             generator=generator
-        ) - 1)).expand(batch_size, 1, 2, 2)
-        image = F.interpolate(
-            torch.normal(
-                mean=offset,
-                std=0.3,
-                generator=generator
-            ).clamp(-1, 1),
-            size=(width, height),
-            mode="bicubic"
+        ) - 1)
+        noise = perlin_noise(
+            batch_size, width, height, res=3, octaves=3, generator=generator, dtype=dtype, device=device
         ).expand(batch_size, 3, width, height)
-        return image
+        return ((1 + max) * noise + max * offset).clamp(-1, 1)
 
     def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -384,7 +381,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        max_init_offset: float = 0.7,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -474,7 +470,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             batch_size * num_images_per_prompt,
             width,
             height,
-            max_init_offset,
             prompt_embeds.dtype,
             device,
             generator
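Taken together, the three hunks above replace the old init image with multi-octave Perlin noise plus a per-sample brightness offset at a fixed strength of 0.4, so the caller-facing max_init_offset parameter disappears from __call__. A standalone sketch of the new computation, restated from the hunk (make_init_image is an illustrative name, and the commit's `max` is renamed max_offset here to avoid shadowing the builtin):

import torch

from util.noise import perlin_noise  # helper added by this commit


def make_init_image(batch_size, width, height, dtype, device, generator=None):
    max_offset = 0.4  # fixed strength; the commit binds it to the name `max`
    # One uniform brightness offset per sample, in [-max_offset, max_offset].
    offset = max_offset * (2 * torch.rand(
        (batch_size, 1, 1, 1), dtype=dtype, device=device, generator=generator
    ) - 1)
    # Low-frequency Perlin noise, (batch, 1, width, height) broadcast to RGB.
    noise = perlin_noise(
        batch_size, width, height, res=3, octaves=3,
        generator=generator, dtype=dtype, device=device
    ).expand(batch_size, 3, width, height)
    # Scale the noise, add the offset, and clamp back into the [-1, 1] image range.
    return ((1 + max_offset) * noise + max_offset * offset).clamp(-1, 1)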
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9b91172..dd2bf6e 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
 from training.lr import plot_metrics
diff --git a/train_lora.py b/train_lora.py
index e213e3d..6e72376 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -15,7 +15,7 @@ from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
 from training.lr import plot_metrics
diff --git a/train_ti.py b/train_ti.py
index c139cc0..b9d6e56 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
diff --git a/util.py b/util/files.py
similarity index 100%
rename from util.py
rename to util/files.py
diff --git a/util/noise.py b/util/noise.py
new file mode 100644
index 0000000..38ab172
--- /dev/null
+++ b/util/noise.py
@@ -0,0 +1,57 @@
+import math
+import torch
+
+# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
+
+
+def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None):
+    delta = (res[0] / shape[0], res[1] / shape[1])
+    d = (shape[0] // res[0], shape[1] // res[1])
+
+    grid = torch.stack(torch.meshgrid(
+        torch.arange(0, res[0], delta[0], dtype=dtype, device=device),
+        torch.arange(0, res[1], delta[1], dtype=dtype, device=device),
+        indexing='ij'
+    ), dim=-1) % 1
+    angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device)
+    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
+
+    def tile_grads(slice1, slice2): return gradients[
+        slice1[0]:slice1[1],
+        slice2[0]:slice2[1]
+    ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
+
+    def dot(grad, shift): return (torch.stack((
+        grid[:shape[0], :shape[1], 0] + shift[0],
+        grid[:shape[0], :shape[1], 1] + shift[1]
+    ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
+
+    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+    t = fade(grid[:shape[0], :shape[1]])
+    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
+
+
+def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
+    noise = torch.zeros(shape, dtype=dtype, device=device)
+    frequency = 1
+    amplitude = 1
+    for _ in range(int(octaves)):
+        noise += amplitude * rand_perlin_2d(
+            shape, (frequency*res[0], frequency*res[1]), dtype=dtype, device=device, generator=generator
+        )
+        frequency *= 2
+        amplitude *= persistence
+    return noise
+
+
+def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+    return torch.stack([
+        rand_perlin_2d_octaves(
+            (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+        ).unsqueeze(0)
+        for _
+        in range(batch_size)
+    ])
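Not part of the commit, but for orientation, a minimal usage sketch of the new helper. rand_perlin_2d tiles gradients with repeat_interleave, so width and height should be divisible by res times the highest octave frequency (with the pipeline's res=3, octaves=3 that means multiples of 12; the sizes and seed below are illustrative):

import torch

from util.noise import perlin_noise

generator = torch.Generator().manual_seed(42)
noise = perlin_noise(4, 768, 768, res=3, octaves=3, generator=generator)

print(noise.shape)  # torch.Size([4, 1, 768, 768]); expanded to 3 channels in prepare_image
print(noise.min().item(), noise.max().item())  # roughly within [-1, 1] before the final clamp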