diff options
Diffstat (limited to 'util')
| -rw-r--r-- | util/files.py | 45 | ||||
| -rw-r--r-- | util/noise.py | 57 |
2 files changed, 102 insertions, 0 deletions
diff --git a/util/files.py b/util/files.py new file mode 100644 index 0000000..2712525 --- /dev/null +++ b/util/files.py | |||
| @@ -0,0 +1,45 @@ | |||
| 1 | from pathlib import Path | ||
| 2 | import json | ||
| 3 | |||
| 4 | from models.clip.embeddings import ManagedCLIPTextEmbeddings | ||
| 5 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
| 6 | |||
| 7 | from safetensors import safe_open | ||
| 8 | |||
| 9 | |||
| 10 | def load_config(filename): | ||
| 11 | with open(filename, 'rt') as f: | ||
| 12 | config = json.load(f) | ||
| 13 | |||
| 14 | args = config["args"] | ||
| 15 | |||
| 16 | if "base" in config: | ||
| 17 | args = load_config(Path(filename).parent / config["base"]) | args | ||
| 18 | |||
| 19 | return args | ||
| 20 | |||
| 21 | |||
| 22 | def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path): | ||
| 23 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
| 24 | return [] | ||
| 25 | |||
| 26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] | ||
| 27 | tokens = [filename.stem for filename in filenames] | ||
| 28 | |||
| 29 | new_ids: list[list[int]] = [] | ||
| 30 | new_embeds = [] | ||
| 31 | |||
| 32 | for filename in filenames: | ||
| 33 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
| 34 | embed = file.get_tensor("embed") | ||
| 35 | |||
| 36 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | ||
| 37 | new_ids.append(added) | ||
| 38 | new_embeds.append(embed) | ||
| 39 | |||
| 40 | embeddings.resize(len(tokenizer)) | ||
| 41 | |||
| 42 | for (new_id, embeds) in zip(new_ids, new_embeds): | ||
| 43 | embeddings.add_embed(new_id, embeds) | ||
| 44 | |||
| 45 | return tokens, new_ids | ||
diff --git a/util/noise.py b/util/noise.py new file mode 100644 index 0000000..38ab172 --- /dev/null +++ b/util/noise.py | |||
| @@ -0,0 +1,57 @@ | |||
| 1 | import math | ||
| 2 | import torch | ||
| 3 | |||
| 4 | # 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57 | ||
| 5 | |||
| 6 | |||
| 7 | def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None): | ||
| 8 | delta = (res[0] / shape[0], res[1] / shape[1]) | ||
| 9 | d = (shape[0] // res[0], shape[1] // res[1]) | ||
| 10 | |||
| 11 | grid = torch.stack(torch.meshgrid( | ||
| 12 | torch.arange(0, res[0], delta[0], dtype=dtype, device=device), | ||
| 13 | torch.arange(0, res[1], delta[1], dtype=dtype, device=device), | ||
| 14 | indexing='ij' | ||
| 15 | ), dim=-1) % 1 | ||
| 16 | angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device) | ||
| 17 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) | ||
| 18 | |||
| 19 | def tile_grads(slice1, slice2): return gradients[ | ||
| 20 | slice1[0]:slice1[1], | ||
| 21 | slice2[0]:slice2[1] | ||
| 22 | ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) | ||
| 23 | |||
| 24 | def dot(grad, shift): return (torch.stack(( | ||
| 25 | grid[:shape[0], :shape[1], 0] + shift[0], | ||
| 26 | grid[:shape[0], :shape[1], 1] + shift[1] | ||
| 27 | ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) | ||
| 28 | |||
| 29 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) | ||
| 30 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) | ||
| 31 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) | ||
| 32 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) | ||
| 33 | t = fade(grid[:shape[0], :shape[1]]) | ||
| 34 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) | ||
| 35 | |||
| 36 | |||
| 37 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None): | ||
| 38 | noise = torch.zeros(shape, dtype=dtype, device=device) | ||
| 39 | frequency = 1 | ||
| 40 | amplitude = 1 | ||
| 41 | for _ in range(int(octaves)): | ||
| 42 | noise += amplitude * rand_perlin_2d( | ||
| 43 | shape, (frequency*res[0], frequency*res[1]), dtype=dtype, device=device, generator=generator | ||
| 44 | ) | ||
| 45 | frequency *= 2 | ||
| 46 | amplitude *= persistence | ||
| 47 | return noise | ||
| 48 | |||
| 49 | |||
| 50 | def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): | ||
| 51 | return torch.stack([ | ||
| 52 | rand_perlin_2d_octaves( | ||
| 53 | (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator | ||
| 54 | ).unsqueeze(0) | ||
| 55 | for _ | ||
| 56 | in range(batch_size) | ||
| 57 | ]) | ||
