From 73910b7f55244ce787fc6a3e6af09240ef0cdfd3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 4 Mar 2023 09:46:41 +0100
Subject: Pipeline: Perlin noise for init image

---
 util/files.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
 util/noise.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 102 insertions(+)
 create mode 100644 util/files.py
 create mode 100644 util/noise.py

diff --git a/util/files.py b/util/files.py
new file mode 100644
index 0000000..2712525
--- /dev/null
+++ b/util/files.py
@@ -0,0 +1,45 @@
+from pathlib import Path
+import json
+
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+from safetensors import safe_open
+
+
+def load_config(filename):
+    with open(filename, 'rt') as f:
+        config = json.load(f)
+
+    args = config["args"]
+
+    if "base" in config:
+        args = load_config(Path(filename).parent / config["base"]) | args
+
+    return args
+
+
+def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
+    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+        return [], []
+
+    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    tokens = [filename.stem for filename in filenames]
+
+    new_ids: list[list[int]] = []
+    new_embeds = []
+
+    for filename in filenames:
+        with safe_open(filename, framework="pt", device="cpu") as file:
+            embed = file.get_tensor("embed")
+
+            added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
+            new_ids.append(added)
+            new_embeds.append(embed)
+
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, embeds) in zip(new_ids, new_embeds):
+        embeddings.add_embed(new_id, embeds)
+
+    return tokens, new_ids

diff --git a/util/noise.py b/util/noise.py
new file mode 100644
index 0000000..38ab172
--- /dev/null
+++ b/util/noise.py
@@ -0,0 +1,57 @@
+import math
+import torch
+
+# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
+
+
+def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None):
+    delta = (res[0] / shape[0], res[1] / shape[1])
+    d = (shape[0] // res[0], shape[1] // res[1])
+
+    grid = torch.stack(torch.meshgrid(
+        torch.arange(0, res[0], delta[0], dtype=dtype, device=device),
+        torch.arange(0, res[1], delta[1], dtype=dtype, device=device),
+        indexing='ij'
+    ), dim=-1) % 1
+    angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device)
+    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
+
+    def tile_grads(slice1, slice2): return gradients[
+        slice1[0]:slice1[1],
+        slice2[0]:slice2[1]
+    ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
+
+    def dot(grad, shift): return (torch.stack((
+        grid[:shape[0], :shape[1], 0] + shift[0],
+        grid[:shape[0], :shape[1], 1] + shift[1]
+    ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
+
+    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+    t = fade(grid[:shape[0], :shape[1]])
+    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
+
+
+def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
+    noise = torch.zeros(shape, dtype=dtype, device=device)
+    frequency = 1
+    amplitude = 1
+    for _ in range(int(octaves)):
+        noise += amplitude * rand_perlin_2d(
+            shape, (frequency*res[0], frequency*res[1]), dtype=dtype, device=device, generator=generator
+        )
+        frequency *= 2
+        amplitude *= persistence
+    return noise
+
+
+def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+    return torch.stack([
+        rand_perlin_2d_octaves(
+            (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+        ).unsqueeze(0)
+        for _
+        in range(batch_size)
+    ])
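
A quick sketch of load_config()'s "base" chaining, assuming two hypothetical
config files that are not part of this patch. Because the dict union puts the
child's args on the right-hand side, child values override anything inherited
from the base:

    import json
    from util.files import load_config

    # Hypothetical configs, written out so the example is self-contained.
    with open("base.json", "w") as f:
        json.dump({"args": {"lr": 1e-4, "steps": 1000}}, f)
    with open("child.json", "w") as f:
        json.dump({"base": "base.json", "args": {"lr": 5e-5}}, f)

    # "base" is resolved relative to the referencing file's directory;
    # the union base_args | args lets child values win on key conflicts.
    args = load_config("child.json")
    print(args)  # {'lr': 5e-05, 'steps': 1000}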
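
And a sketch of how perlin_noise() might seed an init-image latent. The patch
only adds the utility, so the pipeline wiring, channel count, and normalization
below are assumptions:

    import torch
    from util.noise import perlin_noise

    generator = torch.Generator().manual_seed(42)

    # Output shape is (batch, 1, width, height). rand_perlin_2d tiles its
    # gradient grid exactly, so width and height should be divisible by
    # res * 2**(octaves - 1).
    noise = perlin_noise(
        4, 64, 64, res=8, octaves=2,
        dtype=torch.float32, generator=generator,
    )

    # Assumed pipeline step: broadcast the single channel across four latent
    # channels and rescale, since octave summation is not unit-variance.
    latents = noise.expand(-1, 4, -1, -1) / noise.std()
    print(latents.shape)  # torch.Size([4, 4, 64, 64])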