diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/noise.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/util/noise.py b/util/noise.py index 3c4f82d..e3ebdb2 100644 --- a/util/noise.py +++ b/util/noise.py | |||
@@ -48,13 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, d | |||
48 | return noise | 48 | return noise |
49 | 49 | ||
50 | 50 | ||
51 | def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None): | 51 | def perlin_noise(shape: tuple[int, int, int, int], res=8, octaves=1, dtype=None, device=None, generator=None): |
52 | return torch.stack([ | 52 | return torch.stack([ |
53 | torch.stack([ | 53 | torch.stack([ |
54 | rand_perlin_2d_octaves( | 54 | rand_perlin_2d_octaves( |
55 | (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator | 55 | (shape[2], shape[3]), (res, res), octaves, dtype=dtype, device=device, generator=generator |
56 | ) | 56 | ) |
57 | for _ in range(channels) | 57 | for _ in range(shape[1]) |
58 | ]) | 58 | ]) |
59 | for _ in range(batch_size) | 59 | for _ in range(shape[0]) |
60 | ]) | 60 | ]) |