-rw-r--r--  infer.py                                             |  2
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 23
-rw-r--r--  train_dreambooth.py                                  |  2
-rw-r--r--  train_lora.py                                        |  2
-rw-r--r--  train_ti.py                                          |  2
-rw-r--r--  util/files.py (renamed from util.py)                 |  0
-rw-r--r--  util/noise.py                                        | 57
7 files changed, 70 insertions, 18 deletions
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -32,7 +32,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 2251848..a6b31d8 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -24,7 +24,9 @@ from diffusers import (
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
+
 from models.clip.util import unify_input_ids, get_extended_embeddings
+from util.noise import perlin_noise
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -304,23 +306,18 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None):
-        offset = (max_offset * (2 * torch.rand(
+    def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
+        max = 0.4
+        offset = max * (2 * torch.rand(
             (batch_size, 1, 1, 1),
             dtype=dtype,
             device=device,
             generator=generator
-        ) - 1)).expand(batch_size, 1, 2, 2)
-        image = F.interpolate(
-            torch.normal(
-                mean=offset,
-                std=0.3,
-                generator=generator
-            ).clamp(-1, 1),
-            size=(width, height),
-            mode="bicubic"
+        ) - 1)
+        noise = perlin_noise(
+            batch_size, width, height, res=3, octaves=3, generator=generator, dtype=dtype, device=device
         ).expand(batch_size, 3, width, height)
-        return image
+        return ((1 + max) * noise + max * offset).clamp(-1, 1)
 
     def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -384,7 +381,6 @@
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        max_init_offset: float = 0.7,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -474,7 +470,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 batch_size * num_images_per_prompt,
                 width,
                 height,
-                max_init_offset,
                 prompt_embeds.dtype,
                 device,
                 generator
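Note: prepare_image previously bicubic-upsampled a tiny 2x2 grid of Gaussian noise centred on a random offset; it now mixes multi-octave Perlin noise with a per-sample brightness offset, with the offset range fixed at 0.4 instead of exposed as max_init_offset. A minimal standalone sketch of the new construction, assuming util/noise.py from this commit is importable (make_init_image is a hypothetical name; the pipeline method itself is prepare_image):

    import torch
    from util.noise import perlin_noise

    def make_init_image(batch_size, width, height, dtype=None, device=None, generator=None):
        # Per-sample scalar offset drawn uniformly from [-0.4, 0.4].
        max_offset = 0.4
        offset = max_offset * (2 * torch.rand(
            (batch_size, 1, 1, 1), dtype=dtype, device=device, generator=generator
        ) - 1)
        # Single-channel Perlin noise broadcast to 3 channels; width and height
        # should be divisible by res * 2**(octaves - 1) = 12 here so each
        # octave's lattice tiles evenly.
        noise = perlin_noise(
            batch_size, width, height, res=3, octaves=3,
            dtype=dtype, device=device, generator=generator
        ).expand(batch_size, 3, width, height)
        return ((1 + max_offset) * noise + max_offset * offset).clamp(-1, 1)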
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9b91172..dd2bf6e 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
 from training.lr import plot_metrics
diff --git a/train_lora.py b/train_lora.py
index e213e3d..6e72376 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -15,7 +15,7 @@ from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
 from training.lr import plot_metrics
diff --git a/train_ti.py b/train_ti.py
index c139cc0..b9d6e56 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
diff --git a/util/noise.py b/util/noise.py
new file mode 100644
index 0000000..38ab172
--- /dev/null
+++ b/util/noise.py
@@ -0,0 +1,57 @@
+import math
+import torch
+
+# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
+
+
+def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None):
+    delta = (res[0] / shape[0], res[1] / shape[1])
+    d = (shape[0] // res[0], shape[1] // res[1])
+
+    grid = torch.stack(torch.meshgrid(
+        torch.arange(0, res[0], delta[0], dtype=dtype, device=device),
+        torch.arange(0, res[1], delta[1], dtype=dtype, device=device),
+        indexing='ij'
+    ), dim=-1) % 1
+    angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device)
+    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
+
+    def tile_grads(slice1, slice2): return gradients[
+        slice1[0]:slice1[1],
+        slice2[0]:slice2[1]
+    ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
+
+    def dot(grad, shift): return (torch.stack((
+        grid[:shape[0], :shape[1], 0] + shift[0],
+        grid[:shape[0], :shape[1], 1] + shift[1]
+    ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
+
+    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+    t = fade(grid[:shape[0], :shape[1]])
+    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
+
+
+def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
+    noise = torch.zeros(shape, dtype=dtype, device=device)
+    frequency = 1
+    amplitude = 1
+    for _ in range(int(octaves)):
+        noise += amplitude * rand_perlin_2d(
+            shape, (frequency*res[0], frequency*res[1]), dtype=dtype, device=device, generator=generator
+        )
+        frequency *= 2
+        amplitude *= persistence
+    return noise
+
+
+def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+    return torch.stack([
+        rand_perlin_2d_octaves(
+            (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+        ).unsqueeze(0)
+        for _
+        in range(batch_size)
+    ])
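The new util/noise.py adapts the linked gist, threading dtype/device/generator through both helpers and adding a batched perlin_noise entry point that stacks per-sample noise into a (batch_size, 1, width, height) tensor. A quick smoke test under that assumption (width and height must be divisible by res * 2**(octaves - 1) so each octave's lattice tiles evenly):

    import torch
    from util.noise import perlin_noise

    gen = torch.Generator().manual_seed(42)  # seeded for reproducibility
    noise = perlin_noise(2, 96, 96, res=3, octaves=3, generator=gen)
    print(noise.shape)  # torch.Size([2, 1, 96, 96])
    # Summed octaves can slightly exceed [-1, 1]; prepare_image clamps the result.
    print(noise.min().item(), noise.max().item())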