diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/files.py | 45 | ||||
-rw-r--r-- | util/noise.py | 57 |
2 files changed, 102 insertions, 0 deletions
diff --git a/util/files.py b/util/files.py new file mode 100644 index 0000000..2712525 --- /dev/null +++ b/util/files.py | |||
@@ -0,0 +1,45 @@ | |||
1 | from pathlib import Path | ||
2 | import json | ||
3 | |||
4 | from models.clip.embeddings import ManagedCLIPTextEmbeddings | ||
5 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
6 | |||
7 | from safetensors import safe_open | ||
8 | |||
9 | |||
10 | def load_config(filename): | ||
11 | with open(filename, 'rt') as f: | ||
12 | config = json.load(f) | ||
13 | |||
14 | args = config["args"] | ||
15 | |||
16 | if "base" in config: | ||
17 | args = load_config(Path(filename).parent / config["base"]) | args | ||
18 | |||
19 | return args | ||
20 | |||
21 | |||
22 | def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path): | ||
23 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
24 | return [] | ||
25 | |||
26 | filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] | ||
27 | tokens = [filename.stem for filename in filenames] | ||
28 | |||
29 | new_ids: list[list[int]] = [] | ||
30 | new_embeds = [] | ||
31 | |||
32 | for filename in filenames: | ||
33 | with safe_open(filename, framework="pt", device="cpu") as file: | ||
34 | embed = file.get_tensor("embed") | ||
35 | |||
36 | added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) | ||
37 | new_ids.append(added) | ||
38 | new_embeds.append(embed) | ||
39 | |||
40 | embeddings.resize(len(tokenizer)) | ||
41 | |||
42 | for (new_id, embeds) in zip(new_ids, new_embeds): | ||
43 | embeddings.add_embed(new_id, embeds) | ||
44 | |||
45 | return tokens, new_ids | ||
diff --git a/util/noise.py b/util/noise.py new file mode 100644 index 0000000..38ab172 --- /dev/null +++ b/util/noise.py | |||
@@ -0,0 +1,57 @@ | |||
1 | import math | ||
2 | import torch | ||
3 | |||
4 | # 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57 | ||
5 | |||
6 | |||
7 | def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None): | ||
8 | delta = (res[0] / shape[0], res[1] / shape[1]) | ||
9 | d = (shape[0] // res[0], shape[1] // res[1]) | ||
10 | |||
11 | grid = torch.stack(torch.meshgrid( | ||
12 | torch.arange(0, res[0], delta[0], dtype=dtype, device=device), | ||
13 | torch.arange(0, res[1], delta[1], dtype=dtype, device=device), | ||
14 | indexing='ij' | ||
15 | ), dim=-1) % 1 | ||
16 | angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device) | ||
17 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) | ||
18 | |||
19 | def tile_grads(slice1, slice2): return gradients[ | ||
20 | slice1[0]:slice1[1], | ||
21 | slice2[0]:slice2[1] | ||
22 | ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) | ||
23 | |||
24 | def dot(grad, shift): return (torch.stack(( | ||
25 | grid[:shape[0], :shape[1], 0] + shift[0], | ||
26 | grid[:shape[0], :shape[1], 1] + shift[1] | ||
27 | ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) | ||
28 | |||
29 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) | ||
30 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) | ||
31 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) | ||
32 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) | ||
33 | t = fade(grid[:shape[0], :shape[1]]) | ||
34 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) | ||
35 | |||
36 | |||
37 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None): | ||
38 | noise = torch.zeros(shape, dtype=dtype, device=device) | ||
39 | frequency = 1 | ||
40 | amplitude = 1 | ||
41 | for _ in range(int(octaves)): | ||
42 | noise += amplitude * rand_perlin_2d( | ||
43 | shape, (frequency*res[0], frequency*res[1]), dtype=dtype, device=device, generator=generator | ||
44 | ) | ||
45 | frequency *= 2 | ||
46 | amplitude *= persistence | ||
47 | return noise | ||
48 | |||
49 | |||
50 | def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): | ||
51 | return torch.stack([ | ||
52 | rand_perlin_2d_octaves( | ||
53 | (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator | ||
54 | ).unsqueeze(0) | ||
55 | for _ | ||
56 | in range(batch_size) | ||
57 | ]) | ||