summaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-04 09:46:41 +0100
committerVolpeon <git@volpeon.ink>2023-03-04 09:46:41 +0100
commit73910b7f55244ce787fc6a3e6af09240ef0cdfd3 (patch)
tree3ef927578fc54b59ab6ff1bd00c3f804c0b9a7bf /util
parentPipeline: Improved initial image generation (diff)
downloadtextual-inversion-diff-73910b7f55244ce787fc6a3e6af09240ef0cdfd3.tar.gz
textual-inversion-diff-73910b7f55244ce787fc6a3e6af09240ef0cdfd3.tar.bz2
textual-inversion-diff-73910b7f55244ce787fc6a3e6af09240ef0cdfd3.zip
Pipeline: Perlin noise for init image
Diffstat (limited to 'util')
-rw-r--r--util/files.py45
-rw-r--r--util/noise.py57
2 files changed, 102 insertions, 0 deletions
diff --git a/util/files.py b/util/files.py
new file mode 100644
index 0000000..2712525
--- /dev/null
+++ b/util/files.py
@@ -0,0 +1,45 @@
1from pathlib import Path
2import json
3
4from models.clip.embeddings import ManagedCLIPTextEmbeddings
5from models.clip.tokenizer import MultiCLIPTokenizer
6
7from safetensors import safe_open
8
9
def load_config(filename):
    """Load a JSON config file and return its "args" section.

    If the config contains a "base" key, the referenced file (resolved
    relative to *filename*'s directory) is loaded recursively and used as
    defaults; keys in the current file's args take precedence.
    """
    # Explicit UTF-8: JSON is UTF-8 by spec; don't rely on the platform
    # default encoding (which differs e.g. on Windows).
    with open(filename, 'rt', encoding="utf-8") as f:
        config = json.load(f)

    args = config["args"]

    if "base" in config:
        # dict union: the right-hand operand (this file's args) wins on clashes
        args = load_config(Path(filename).parent / config["base"]) | args

    return args
20
21
def load_embeddings_from_dir(tokenizer: "MultiCLIPTokenizer", embeddings: "ManagedCLIPTextEmbeddings", embeddings_dir: Path):
    """Load token embeddings stored as safetensors files in *embeddings_dir*.

    Each file's stem is used as the token name and its "embed" tensor holds
    one row per sub-token. The tokens are registered with *tokenizer*, the
    embedding matrix is resized once, and the loaded vectors are copied in.

    Returns:
        (tokens, new_ids) — the token names and their assigned id lists.
        Both lists are empty when the directory is missing.
    """
    if not embeddings_dir.exists() or not embeddings_dir.is_dir():
        # Bug fix: previously returned a bare [] here while the normal path
        # returns a (tokens, new_ids) pair — callers unpacking two values
        # would crash. Return an empty pair instead.
        return [], []

    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
    tokens = [filename.stem for filename in filenames]

    new_ids: list[list[int]] = []
    new_embeds = []

    for filename in filenames:
        with safe_open(filename, framework="pt", device="cpu") as file:
            embed = file.get_tensor("embed")

            # embed.shape[0] = number of sub-token vectors for this token
            added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
            new_ids.append(added)
            new_embeds.append(embed)

    # Grow the embedding matrix once, after all tokens are registered.
    embeddings.resize(len(tokenizer))

    for (new_id, embeds) in zip(new_ids, new_embeds):
        embeddings.add_embed(new_id, embeds)

    return tokens, new_ids
diff --git a/util/noise.py b/util/noise.py
new file mode 100644
index 0000000..38ab172
--- /dev/null
+++ b/util/noise.py
@@ -0,0 +1,57 @@
1import math
2import torch
3
4# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
5
6
def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None):
    """Generate one 2D Perlin-noise tensor of size *shape*.

    shape: output size per axis; assumes each entry is divisible by the
           matching *res* entry (the // below relies on it — TODO confirm
           at call sites).
    res:   number of noise cells along each axis.
    fade:  interpolation curve; default is Perlin's smoothstep 6t^5-15t^4+10t^3.
    generator: optional torch.Generator for reproducible gradients.
    """
    # Step between output samples within one cell, and samples per cell.
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

    # Fractional (x, y) position of every output sample inside its cell, in [0, 1).
    grid = torch.stack(torch.meshgrid(
        torch.arange(0, res[0], delta[0], dtype=dtype, device=device),
        torch.arange(0, res[1], delta[1], dtype=dtype, device=device),
        indexing='ij'
    ), dim=-1) % 1
    # One random unit gradient vector per lattice corner ((res+1) corners per axis).
    angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    # Select one corner per cell via (slice1, slice2), then repeat each
    # corner's gradient over its d[0] x d[1] block of output samples.
    def tile_grads(slice1, slice2): return gradients[
        slice1[0]:slice1[1],
        slice2[0]:slice2[1]
    ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)

    # Dot product of a corner's gradient with the sample's offset from that
    # corner; *shift* re-bases the in-cell position onto the chosen corner.
    def dot(grad, shift): return (torch.stack((
        grid[:shape[0], :shape[1], 0] + shift[0],
        grid[:shape[0], :shape[1], 1] + shift[1]
    ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)

    # Contributions from the four corners of each cell.
    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
    # Faded interpolation weights, then lerp along axis 0 (t[..., 0]) and
    # axis 1 (t[..., 1]); sqrt(2) rescales the result toward [-1, 1].
    t = fade(grid[:shape[0], :shape[1]])
    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
35
36
def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
    """Accumulate several octaves of 2D Perlin noise into one tensor.

    Each successive octave doubles the cell resolution and scales its
    contribution by *persistence*; the weighted layers are summed into a
    tensor of size *shape*.
    """
    total = torch.zeros(shape, dtype=dtype, device=device)
    freq, amp = 1, 1
    for _ in range(int(octaves)):
        layer = rand_perlin_2d(
            shape, (freq*res[0], freq*res[1]), dtype=dtype, device=device, generator=generator
        )
        total = total + amp * layer
        freq, amp = freq * 2, amp * persistence
    return total
48
49
def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
    """Produce a batch of single-channel Perlin noise, shape (batch_size, 1, width, height).

    Each batch entry is sampled independently; when *generator* is supplied
    it advances across the samples, so entries still differ.
    """
    samples = []
    for _ in range(batch_size):
        sample = rand_perlin_2d_octaves(
            (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
        )
        samples.append(sample.unsqueeze(0))  # add the channel dimension
    return torch.stack(samples)