diff options
Diffstat (limited to 'util/noise.py')
-rw-r--r-- | util/noise.py | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/util/noise.py b/util/noise.py new file mode 100644 index 0000000..38ab172 --- /dev/null +++ b/util/noise.py | |||
@@ -0,0 +1,57 @@ | |||
1 | import math | ||
2 | import torch | ||
3 | |||
4 | # 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57 | ||
5 | |||
6 | |||
7 | def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None): | ||
8 | delta = (res[0] / shape[0], res[1] / shape[1]) | ||
9 | d = (shape[0] // res[0], shape[1] // res[1]) | ||
10 | |||
11 | grid = torch.stack(torch.meshgrid( | ||
12 | torch.arange(0, res[0], delta[0], dtype=dtype, device=device), | ||
13 | torch.arange(0, res[1], delta[1], dtype=dtype, device=device), | ||
14 | indexing='ij' | ||
15 | ), dim=-1) % 1 | ||
16 | angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device) | ||
17 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) | ||
18 | |||
19 | def tile_grads(slice1, slice2): return gradients[ | ||
20 | slice1[0]:slice1[1], | ||
21 | slice2[0]:slice2[1] | ||
22 | ].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1) | ||
23 | |||
24 | def dot(grad, shift): return (torch.stack(( | ||
25 | grid[:shape[0], :shape[1], 0] + shift[0], | ||
26 | grid[:shape[0], :shape[1], 1] + shift[1] | ||
27 | ), dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1) | ||
28 | |||
29 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) | ||
30 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) | ||
31 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) | ||
32 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) | ||
33 | t = fade(grid[:shape[0], :shape[1]]) | ||
34 | return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]) | ||
35 | |||
36 | |||
37 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None): | ||
38 | noise = torch.zeros(shape, dtype=dtype, device=device) | ||
39 | frequency = 1 | ||
40 | amplitude = 1 | ||
41 | for _ in range(int(octaves)): | ||
42 | noise += amplitude * rand_perlin_2d( | ||
43 | shape, (frequency*res[0], frequency*res[1]), dtype=dtype, device=device, generator=generator | ||
44 | ) | ||
45 | frequency *= 2 | ||
46 | amplitude *= persistence | ||
47 | return noise | ||
48 | |||
49 | |||
50 | def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): | ||
51 | return torch.stack([ | ||
52 | rand_perlin_2d_octaves( | ||
53 | (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator | ||
54 | ).unsqueeze(0) | ||
55 | for _ | ||
56 | in range(batch_size) | ||
57 | ]) | ||