1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
|
import math
import torch
# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None):
    """Sample a single 2D Perlin noise plane.

    The ``shape[0] x shape[1]`` output is covered by a lattice of
    ``res[0] x res[1]`` cells. A random unit gradient is drawn at each of the
    ``(res[0] + 1) x (res[1] + 1)`` lattice corners. Each pixel blends the dot
    products with its four surrounding corners using the ``fade`` curve.

    Args:
        shape: ``(H, W)`` output size in pixels.
        res: ``(rows, cols)`` number of lattice cells per axis (ints). ``shape``
            does not need to be a multiple of ``res``.
        fade: Interpolation curve applied to the in-cell coordinates. It is
            called on an ``(H, W, 2)`` tensor. The default is the quintic
            ``6t^5 - 15t^4 + 10t^3``.
        dtype: Forwarded to the created tensors.
        device: Forwarded to the created tensors.
        generator: Optional ``torch.Generator`` used to draw the gradient angles.

    Returns:
        Tensor of shape ``(H, W)``, scaled by ``sqrt(2)``. It is exactly 0 at
        pixels that fall on lattice points.
    """
    # One random unit gradient per lattice corner. This is the only RNG draw.
    angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1, generator=generator, dtype=dtype, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def _axis(size, cells):
        # Pixel i sits at lattice coordinate i * cells / size. Splitting it into
        # (cell index, fractional offset) with integer arithmetic keeps cell
        # boundaries exact. A float arange followed by `% 1` can produce k - eps,
        # which pairs an offset near 1 with the next cell's gradients. It also
        # works when size is not a multiple of cells.
        scaled = torch.arange(size, device=device) * cells
        cell = torch.div(scaled, size, rounding_mode='floor')
        frac = (scaled - cell * size).to(gradients.dtype) / size
        return cell, frac

    cell0, frac0 = _axis(shape[0], res[0])
    cell1, frac1 = _axis(shape[1], res[1])
    idx0, idx1 = torch.meshgrid(cell0, cell1, indexing='ij')
    # In-cell coordinates in [0, 1), shape (H, W, 2).
    grid = torch.stack(torch.meshgrid(frac0, frac1, indexing='ij'), dim=-1)

    def dot(di, dj):
        # Contribution of corner (cell + (di, dj)): its gradient dotted with the
        # pixel's offset from that corner.
        grad = gradients[idx0 + di, idx1 + dj]
        return (grid[..., 0] - di) * grad[..., 0] + (grid[..., 1] - dj) * grad[..., 1]

    n00 = dot(0, 0)
    n10 = dot(1, 0)
    n01 = dot(0, 1)
    n11 = dot(1, 1)
    t = fade(grid)
    # Interpolate along axis 0 first, then along axis 1.
    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
    """Sum several octaves of 2D Perlin noise.

    Octave k is a ``rand_perlin_2d`` plane whose lattice resolution is
    ``res * 2**k``, weighted by ``persistence**k`` (accumulated by repeated
    multiplication). ``int(octaves)`` octaves are summed, so ``octaves=0``
    yields all zeros. Any shape/res constraints of ``rand_perlin_2d`` apply
    to every octave, including the finest one.

    Args:
        shape: ``(H, W)`` output size.
        res: ``(rows, cols)`` lattice resolution of the first octave.
        octaves: Number of octaves to sum. Truncated with ``int``.
        persistence: Amplitude multiplier applied between successive octaves.
        dtype: Forwarded to ``torch.zeros`` and ``rand_perlin_2d``.
        device: Forwarded to ``torch.zeros`` and ``rand_perlin_2d``.
        generator: Forwarded to ``rand_perlin_2d``. Octaves are drawn in
            order of increasing resolution.

    Returns:
        Tensor of shape ``shape`` holding the weighted sum.
    """
    result = torch.zeros(shape, dtype=dtype, device=device)
    weight, scale = 1, 1
    for _octave in range(int(octaves)):
        octave_res = (scale * res[0], scale * res[1])
        layer = rand_perlin_2d(shape, octave_res, dtype=dtype, device=device, generator=generator)
        result += weight * layer
        # Next octave: twice the lattice resolution, amplitude scaled by persistence.
        scale, weight = scale * 2, weight * persistence
    return result
def perlin_noise(shape: tuple[int, int, int, int], res=8, octaves=1, dtype=None, device=None, generator=None, persistence=0.5):
    """Build a ``(B, C, H, W)`` tensor of independent Perlin-noise planes.

    One ``rand_perlin_2d_octaves`` plane of size ``(H, W)`` is drawn for every
    (batch, channel) pair. Planes are drawn in batch-major order: all channels
    of batch 0, then batch 1, and so on.

    Args:
        shape: ``(B, C, H, W)`` output shape. Only the first four entries are used.
        res: Lattice resolution of the first octave. Either an int (used for both
            axes) or an ``(rows, cols)`` pair.
        octaves: Number of octaves, forwarded to ``rand_perlin_2d_octaves``.
        dtype: Forwarded to the noise generators and the output tensor.
        device: Forwarded to the noise generators and the output tensor.
        generator: Forwarded to the noise generators.
        persistence: Amplitude multiplier between octaves (default 0.5).

    Returns:
        Tensor of shape ``(B, C, H, W)``. A zero ``B`` or ``C`` yields an empty
        tensor instead of raising from ``torch.stack([])``.
    """
    batch, channels, height, width = shape[0], shape[1], shape[2], shape[3]
    try:
        res_h, res_w = res
    except TypeError:  # a single int applies to both axes
        res_h = res_w = res
    # Preallocate rather than stacking nested lists. This also handles empty
    # batch/channel dimensions.
    noise = torch.empty((batch, channels, height, width), dtype=dtype, device=device)
    for b in range(batch):
        for c in range(channels):
            noise[b, c] = rand_perlin_2d_octaves(
                (height, width), (res_h, res_w), octaves, persistence=persistence,
                dtype=dtype, device=device, generator=generator,
            )
    return noise
|