diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/noise.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/util/noise.py b/util/noise.py index 38ab172..3c4f82d 100644 --- a/util/noise.py +++ b/util/noise.py | |||
@@ -1,4 +1,5 @@ | |||
1 | import math | 1 | import math |
2 | |||
2 | import torch | 3 | import torch |
3 | 4 | ||
4 | # 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57 | 5 | # 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57 |
@@ -47,11 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, d | |||
47 | return noise | 48 | return noise |
48 | 49 | ||
49 | 50 | ||
50 | def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): | 51 | def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): |
51 | return torch.stack([ | 52 | return torch.stack([ |
52 | rand_perlin_2d_octaves( | 53 | torch.stack([ |
53 | (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator | 54 | rand_perlin_2d_octaves( |
54 | ).unsqueeze(0) | 55 | (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator |
55 | for _ | 56 | ) |
56 | in range(batch_size) | 57 | for _ in range(channels) |
58 | ]) | ||
59 | for _ in range(batch_size) | ||
57 | ]) | 60 | ]) |