Diffstat:
 -rw-r--r--  infer.py                                             |  2
 -rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 23
 -rw-r--r--  train_dreambooth.py                                  |  2
 -rw-r--r--  train_lora.py                                        |  2
 -rw-r--r--  train_ti.py                                          |  2
 -rw-r--r--  util/files.py (renamed from util.py)                 |  0
 -rw-r--r--  util/noise.py                                        | 57

 7 files changed, 70 insertions, 18 deletions
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -32,7 +32,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 2251848..a6b31d8 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -24,7 +24,9 @@ from diffusers import (
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
+
 from models.clip.util import unify_input_ids, get_extended_embeddings
+from util.noise import perlin_noise
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
@@ -304,23 +306,18 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None):
-        offset = (max_offset * (2 * torch.rand(
+    def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
+        max = 0.4
+        offset = max * (2 * torch.rand(
             (batch_size, 1, 1, 1),
             dtype=dtype,
             device=device,
             generator=generator
-        ) - 1)).expand(batch_size, 1, 2, 2)
-        image = F.interpolate(
-            torch.normal(
-                mean=offset,
-                std=0.3,
-                generator=generator
-            ).clamp(-1, 1),
-            size=(width, height),
-            mode="bicubic"
+        ) - 1)
+        noise = perlin_noise(
+            batch_size, width, height, res=3, octaves=3, generator=generator, dtype=dtype, device=device
         ).expand(batch_size, 3, width, height)
-        return image
+        return ((1 + max) * noise + max * offset).clamp(-1, 1)
 
     def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
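The rewritten prepare_image drops the bicubically upscaled Gaussian blob in favour of multi-octave Perlin noise plus a random per-sample brightness offset; the max_init_offset argument is removed below, superseded by the fixed max = 0.4 inside the method. A minimal standalone sketch of the same computation, assuming a 96×96 target so that every octave's grid divides the image evenly (max is renamed max_offset here to avoid shadowing the Python builtin):

    import torch
    from util.noise import perlin_noise

    batch_size, width, height = 2, 96, 96
    generator = torch.Generator().manual_seed(0)

    # Per-sample scalar offset drawn uniformly from [-0.4, 0.4].
    max_offset = 0.4
    offset = max_offset * (2 * torch.rand((batch_size, 1, 1, 1), generator=generator) - 1)

    # Low-frequency Perlin noise, shape (batch, 1, width, height),
    # broadcast across the three colour channels.
    noise = perlin_noise(batch_size, width, height, res=3, octaves=3, generator=generator)
    noise = noise.expand(batch_size, 3, width, height)

    # Scale, shift, and clamp to the [-1, 1] range the VAE encoder expects.
    image = ((1 + max_offset) * noise + max_offset * offset).clamp(-1, 1)
    print(image.shape)  # torch.Size([2, 3, 96, 96])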
@@ -384,7 +381,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        max_init_offset: float = 0.7,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -474,7 +470,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             batch_size * num_images_per_prompt,
             width,
             height,
-            max_init_offset,
             prompt_embeds.dtype,
             device,
             generator
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9b91172..dd2bf6e 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
 from training.lr import plot_metrics
diff --git a/train_lora.py b/train_lora.py
index e213e3d..6e72376 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -15,7 +15,7 @@ from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
 from training.lr import plot_metrics
diff --git a/train_ti.py b/train_ti.py
index c139cc0..b9d6e56 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
diff --git a/util/noise.py b/util/noise.py
new file mode 100644
index 0000000..38ab172
--- /dev/null
+++ b/util/noise.py
@@ -0,0 +1,57 @@
+import math
+import torch
+
+# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
+
+
+def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None):
+    delta = (res[0] / shape[0], res[1] / shape[1])
+    d = (shape[0] // res[0], shape[1] // res[1])
+
+    grid = torch.stack(torch.meshgrid(
+        torch.arange(0, res[0], delta[0], dtype=dtype, device=device),
+        torch.arange(0, res[1], delta[1], dtype=dtype, device=device),
+        indexing='ij'
+    ), dim=-1) % 1
+    angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device)
+    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
+
+    def tile_grads(slice1, slice2): return gradients[
+        slice1[0]:slice1[1],
+        slice2[0]:slice2[1]
+    ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
+
+    def dot(grad, shift): return (torch.stack((
+        grid[:shape[0], :shape[1], 0] + shift[0],
+        grid[:shape[0], :shape[1], 1] + shift[1]
+    ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
+
+    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+    t = fade(grid[:shape[0], :shape[1]])
+    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
+
+
+def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
+    noise = torch.zeros(shape, dtype=dtype, device=device)
+    frequency = 1
+    amplitude = 1
+    for _ in range(int(octaves)):
+        noise += amplitude * rand_perlin_2d(
+            shape, (frequency*res[0], frequency*res[1]), dtype=dtype, device=device, generator=generator
+        )
+        frequency *= 2
+        amplitude *= persistence
+    return noise
+
+
+def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+    return torch.stack([
+        rand_perlin_2d_octaves(
+            (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+        ).unsqueeze(0)
+        for _
+        in range(batch_size)
+    ])
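A quick sanity check of the new helper, as a sketch. One caveat worth flagging: the gist-derived code assumes width and height are divisible by res * 2**(octaves - 1), i.e. by 12 for the res=3, octaves=3 settings the pipeline now uses; sizes that don't divide evenly fail with a shape mismatch inside rand_perlin_2d.

    import torch
    from util.noise import perlin_noise

    gen = torch.Generator().manual_seed(42)
    noise = perlin_noise(2, 96, 96, res=3, octaves=3, generator=gen)

    print(noise.shape)  # torch.Size([2, 1, 96, 96])
    # With persistence=0.5 the octave amplitudes sum to 1.75, so values stay
    # roughly within [-1.75, 1.75]; the pipeline clamps the result afterwards.
    print(noise.min().item(), noise.max().item())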
