summaryrefslogtreecommitdiffstats
path: root/util/noise.py
blob: e3ebdb206fa6cda820b11d4cc5c2d24d2b0b39c7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import math

import torch

# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57


def rand_perlin_2d(shape, res, fade=lambda t: 6*t**5 - 15*t**4 + 10*t**3, dtype=None, device=None, generator=None):
    """Return one 2D Perlin noise field of size ``(shape[0], shape[1])``.

    A lattice of ``res[0] x res[1]`` cells is laid over the output. Every
    lattice corner gets a random unit gradient; each pixel's value is the
    fade-weighted interpolation of the dot products between the four
    surrounding corner gradients and the pixel's offset from that corner.

    ``shape`` does not have to be a multiple of ``res``, and ``res`` may be
    larger than ``shape``. When ``shape`` is a multiple of ``res`` the
    pixel-to-cell mapping is the same ``i // (shape // res)`` tiling that a
    ``repeat_interleave`` of the gradients would produce.

    Args:
        shape: Pair ``(rows, cols)`` giving the output size in pixels.
        res: Pair of ints giving the number of lattice cells per axis.
        fade: Callable applied to the ``(rows, cols, 2)`` tensor of in-cell
            fractional coordinates to obtain the interpolation weights.
        dtype: Floating dtype forwarded to ``torch.rand``; the result uses
            the dtype of the sampled gradients.
        device: Device on which all tensors are created.
        generator: Optional ``torch.Generator`` used to sample the gradients.

    Returns:
        Tensor of shape ``(shape[0], shape[1])``, scaled by ``sqrt(2)``.
    """
    # One random unit gradient per lattice corner (res + 1 corners per axis).
    angles = 2*math.pi*torch.rand(res[0]+1, res[1]+1, generator=generator, dtype=dtype, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def split_axis(size, cells):
        # Pixel i sits at lattice coordinate i * cells / size. Deriving the
        # cell index and the in-cell fraction from the same integer product
        # keeps them consistent (no float rounding at cell borders) and
        # guarantees cell + 1 <= cells, so corner lookups stay in bounds.
        scaled = torch.arange(size, device=device) * cells
        cell = torch.div(scaled, size, rounding_mode='floor')
        frac = torch.remainder(scaled, size).to(gradients.dtype) / size
        return cell, frac

    cell0, frac0 = split_axis(shape[0], res[0])
    cell1, frac1 = split_axis(shape[1], res[1])

    # (rows, cols, 2) tensor of fractional positions inside each pixel's cell.
    grid = torch.stack(torch.meshgrid(frac0, frac1, indexing='ij'), dim=-1)
    rows = cell0[:, None]
    cols = cell1[None, :]

    def dot(corner):
        # Gradient of the requested cell corner, gathered per pixel, dotted
        # with the offset from that corner to the pixel.
        grad = gradients[rows + corner[0], cols + corner[1]]
        offset = grid - grid.new_tensor(corner)
        return (offset * grad).sum(dim=-1)

    n00 = dot((0, 0))
    n10 = dot((1, 0))
    n01 = dot((0, 1))
    n11 = dot((1, 1))
    t = fade(grid)
    # Interpolate along axis 0 first, then along axis 1.
    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])


def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
    """Sum several Perlin noise layers of increasing frequency.

    Starts from a zero tensor of size ``shape`` and adds ``int(octaves)``
    layers produced by ``rand_perlin_2d``. The first layer uses ``res`` and
    weight 1; each following layer doubles the lattice resolution and
    multiplies the weight by ``persistence``.

    Args:
        shape: Output size, forwarded to ``torch.zeros`` and ``rand_perlin_2d``.
        res: Pair giving the lattice resolution of the first layer.
        octaves: Number of layers to accumulate (truncated with ``int``).
        persistence: Factor applied to the layer weight after every layer.
        dtype: Forwarded to ``torch.zeros`` and ``rand_perlin_2d``.
        device: Forwarded to ``torch.zeros`` and ``rand_perlin_2d``.
        generator: Forwarded to ``rand_perlin_2d``.

    Returns:
        The accumulated noise tensor of size ``shape``.
    """
    total = torch.zeros(shape, dtype=dtype, device=device)
    scale, weight = 1, 1
    remaining = int(octaves)
    while remaining > 0:
        layer = rand_perlin_2d(
            shape,
            (scale * res[0], scale * res[1]),
            dtype=dtype,
            device=device,
            generator=generator,
        )
        total += weight * layer
        # Next octave: finer lattice, smaller contribution.
        scale *= 2
        weight *= persistence
        remaining -= 1
    return total


def perlin_noise(shape: tuple[int, int, int, int], res=8, octaves=1, dtype=None, device=None, generator=None):
    """Build a 4D tensor whose trailing two axes are Perlin noise planes.

    For every index pair along the first two axes a separate plane of size
    ``(shape[2], shape[3])`` is drawn with ``rand_perlin_2d_octaves`` using a
    square lattice ``(res, res)``. Planes are generated in row-major order
    over the first two axes, all from the same ``generator``.

    Args:
        shape: Four ints; the first two count the planes, the last two give
            each plane's size.
        res: Lattice resolution used for both axes of every plane.
        octaves: Number of octaves forwarded to ``rand_perlin_2d_octaves``.
        dtype: Forwarded to ``rand_perlin_2d_octaves``.
        device: Forwarded to ``rand_perlin_2d_octaves``.
        generator: Forwarded to ``rand_perlin_2d_octaves``.

    Returns:
        The planes stacked along two new leading axes.
    """
    plane_size = (shape[2], shape[3])
    lattice = (res, res)

    outer = []
    for _ in range(shape[0]):
        inner = []
        for _ in range(shape[1]):
            plane = rand_perlin_2d_octaves(
                plane_size, lattice, octaves, dtype=dtype, device=device, generator=generator
            )
            inner.append(plane)
        outer.append(torch.stack(inner))
    return torch.stack(outer)