From 73910b7f55244ce787fc6a3e6af09240ef0cdfd3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 4 Mar 2023 09:46:41 +0100
Subject: Pipeline: Perlin noise for init image

---
 infer.py                                           |  2 +-
 .../stable_diffusion/vlpn_stable_diffusion.py      | 23 ++++-----
 train_dreambooth.py                                |  2 +-
 train_lora.py                                      |  2 +-
 train_ti.py                                        |  2 +-
 util.py                                            | 45 -----------------
 util/files.py                                      | 45 +++++++++++++++++
 util/noise.py                                      | 57 ++++++++++++++++++++++
 8 files changed, 115 insertions(+), 63 deletions(-)
 delete mode 100644 util.py
 create mode 100644 util/files.py
 create mode 100644 util/noise.py

diff --git a/infer.py b/infer.py
index 07dcd22..cf59bba 100644
--- a/infer.py
+++ b/infer.py
@@ -32,7 +32,7 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 
 torch.backends.cuda.matmul.allow_tf32 = True
 
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 2251848..a6b31d8 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -24,7 +24,9 @@ from diffusers import (
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
+
 from models.clip.util import unify_input_ids, get_extended_embeddings
+from util.noise import perlin_noise
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -304,23 +306,18 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None):
-        offset = (max_offset * (2 * torch.rand(
+    def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
+        max = 0.4
+        offset = max * (2 * torch.rand(
             (batch_size, 1, 1, 1),
             dtype=dtype,
             device=device,
             generator=generator
-        ) - 1)).expand(batch_size, 1, 2, 2)
-        image = F.interpolate(
-            torch.normal(
-                mean=offset,
-                std=0.3,
-                generator=generator
-            ).clamp(-1, 1),
-            size=(width, height),
-            mode="bicubic"
+        ) - 1)
+        noise = perlin_noise(
+            batch_size, width, height, res=3, octaves=3, generator=generator, dtype=dtype, device=device
         ).expand(batch_size, 3, width, height)
-        return image
+        return ((1 + max) * noise + max * offset).clamp(-1, 1)
 
     def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
@@ -384,7 +381,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        max_init_offset: float = 0.7,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -474,7 +470,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 batch_size * num_images_per_prompt,
                 width,
                 height,
-                max_init_offset,
                 prompt_embeds.dtype,
                 device,
                 generator
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9b91172..dd2bf6e 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
 from training.lr import plot_metrics
diff --git a/train_lora.py b/train_lora.py
index e213e3d..6e72376 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -15,7 +15,7 @@ from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
 from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, get_models
 from training.lr import plot_metrics
diff --git a/train_ti.py b/train_ti.py
index c139cc0..b9d6e56 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
-from util import load_config, load_embeddings_from_dir
+from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
diff --git a/util.py b/util.py
deleted file mode 100644
index 2712525..0000000
--- a/util.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from pathlib import Path
-import json
-
-from models.clip.embeddings import ManagedCLIPTextEmbeddings
-from models.clip.tokenizer import MultiCLIPTokenizer
-
-from safetensors import safe_open
-
-
-def load_config(filename):
-    with open(filename, 'rt') as f:
-        config = json.load(f)
-
-    args = config["args"]
-
-    if "base" in config:
-        args = load_config(Path(filename).parent / config["base"]) | args
-
-    return args
-
-
-def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
-    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-        return []
-
-    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
-    tokens = [filename.stem for filename in filenames]
-
-    new_ids: list[list[int]] = []
-    new_embeds = []
-
-    for filename in filenames:
-        with safe_open(filename, framework="pt", device="cpu") as file:
-            embed = file.get_tensor("embed")
-
-            added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
-            new_ids.append(added)
-            new_embeds.append(embed)
-
-    embeddings.resize(len(tokenizer))
-
-    for (new_id, embeds) in zip(new_ids, new_embeds):
-        embeddings.add_embed(new_id, embeds)
-
-    return tokens, new_ids
diff --git a/util/files.py b/util/files.py
new file mode 100644
index 0000000..2712525
--- /dev/null
+++ b/util/files.py
@@ -0,0 +1,45 @@
+from pathlib import Path
+import json
+
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+from safetensors import safe_open
+
+
+def load_config(filename):
+    with open(filename, 'rt') as f:
+        config = json.load(f)
+
+    args = config["args"]
+
+    if "base" in config:
+        args = load_config(Path(filename).parent / config["base"]) | args
+
+    return args
+
+
+def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
+    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+        return []
+
+    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    tokens = [filename.stem for filename in filenames]
+
+    new_ids: list[list[int]] = []
+    new_embeds = []
+
+    for filename in filenames:
+        with safe_open(filename, framework="pt", device="cpu") as file:
+            embed = file.get_tensor("embed")
+
+            added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
+            new_ids.append(added)
+            new_embeds.append(embed)
+
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, embeds) in zip(new_ids, new_embeds):
+        embeddings.add_embed(new_id, embeds)
+
+    return tokens, new_ids
diff --git a/util/noise.py b/util/noise.py
new file mode 100644
index 0000000..38ab172
--- /dev/null
+++ b/util/noise.py
@@ -0,0 +1,57 @@
+import math
+import torch
+
+# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
+
+
+def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None):
+    delta = (res[0] / shape[0], res[1] / shape[1])
+    d = (shape[0] // res[0], shape[1] // res[1])
+
+    grid = torch.stack(torch.meshgrid(
+        torch.arange(0, res[0], delta[0], dtype=dtype, device=device),
+        torch.arange(0, res[1], delta[1], dtype=dtype, device=device),
+        indexing='ij'
+    ), dim=-1) % 1
+    angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device)
+    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
+
+    def tile_grads(slice1, slice2): return gradients[
+        slice1[0]:slice1[1],
+        slice2[0]:slice2[1]
+    ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
+
+    def dot(grad, shift): return (torch.stack((
+        grid[:shape[0], :shape[1], 0] + shift[0],
+        grid[:shape[0], :shape[1], 1] + shift[1]
+    ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
+
+    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+    t = fade(grid[:shape[0], :shape[1]])
+    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
+
+
+def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
+    noise = torch.zeros(shape, dtype=dtype, device=device)
+    frequency = 1
+    amplitude = 1
+    for _ in range(int(octaves)):
+        noise += amplitude * rand_perlin_2d(
+            shape, (frequency*res[0], frequency*res[1]), dtype=dtype, device=device, generator=generator
+        )
+        frequency *= 2
+        amplitude *= persistence
+    return noise
+
+
+def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+    return torch.stack([
+        rand_perlin_2d_octaves(
+            (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+        ).unsqueeze(0)
+        for _
+        in range(batch_size)
+    ])
-- 
cgit v1.2.3-70-g09d2
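
Example (not part of the commit): a minimal sketch of how the perlin_noise helper added in util/noise.py could be called on its own to produce a smooth, low-frequency init tensor, covering the noise-generation half of the reworked prepare_image. The batch size, the 768x768 resolution, the seed, and the float32 dtype below are illustrative assumptions; only res=3 and octaves=3 come from the patch. Note that this gist-based implementation tiles gradients evenly, so width and height should be divisible by res * 2**(octaves - 1), i.e. 12 here.

import torch

from util.noise import perlin_noise  # assumes the repository root is on PYTHONPATH

device = torch.device("cpu")                                # illustrative; any torch device works
generator = torch.Generator(device=device).manual_seed(0)   # illustrative seed

# Low-frequency Perlin noise of shape (batch, 1, width, height); res/octaves match
# the values the pipeline now passes, the remaining arguments are assumptions.
noise = perlin_noise(
    4, 768, 768, res=3, octaves=3,
    dtype=torch.float32, device=device, generator=generator,
)

# Broadcast the single channel to RGB and clamp to the [-1, 1] image range,
# similar to what prepare_image does before mixing in its random per-image offset.
image = noise.expand(-1, 3, -1, -1).clamp(-1, 1)
print(image.shape)  # torch.Size([4, 3, 768, 768])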