summaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/noise.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/util/noise.py b/util/noise.py
index 38ab172..3c4f82d 100644
--- a/util/noise.py
+++ b/util/noise.py
@@ -1,4 +1,5 @@
1import math 1import math
2
2import torch 3import torch
3 4
4# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57 5# 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
@@ -47,11 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, d
47 return noise 48 return noise
48 49
49 50
50def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): 51def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
51 return torch.stack([ 52 return torch.stack([
52 rand_perlin_2d_octaves( 53 torch.stack([
53 (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator 54 rand_perlin_2d_octaves(
54 ).unsqueeze(0) 55 (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
55 for _ 56 )
56 in range(batch_size) 57 for _ in range(channels)
58 ])
59 for _ in range(batch_size)
57 ]) 60 ])