diff options
author | Volpeon <git@volpeon.ink> | 2023-03-01 14:10:23 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-01 14:10:23 +0100 |
commit | 51394430b6b142eb21641f251b5fe32cdf802ab8 (patch) | |
tree | d47033374f64a6c079571169aff57e8aff7b57d7 /training | |
parent | Update (diff) | |
download | textual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.tar.gz textual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.tar.bz2 textual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.zip |
Changed low freq noise
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 33 |
1 files changed, 10 insertions, 23 deletions
diff --git a/training/functional.py b/training/functional.py index 990c4cd..8ea40bb 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -268,35 +268,22 @@ def loss_step( | |||
268 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | 268 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None |
269 | 269 | ||
270 | # Sample noise that we'll add to the latents | 270 | # Sample noise that we'll add to the latents |
271 | if low_freq_noise == 0: | 271 | noise = torch.randn( |
272 | noise = torch.randn( | 272 | latents.shape, |
273 | latents.shape, | 273 | dtype=latents.dtype, |
274 | dtype=latents.dtype, | 274 | layout=latents.layout, |
275 | layout=latents.layout, | 275 | device=latents.device, |
276 | device=latents.device, | 276 | generator=generator |
277 | generator=generator | 277 | ) |
278 | ) | 278 | |
279 | else: | 279 | if low_freq_noise != 0: |
280 | noise = (1 - low_freq_noise) * torch.randn( | 280 | noise *= 1 - low_freq_noise + low_freq_noise * torch.randn( |
281 | latents.shape, | ||
282 | dtype=latents.dtype, | ||
283 | layout=latents.layout, | ||
284 | device=latents.device, | ||
285 | generator=generator | ||
286 | ) + low_freq_noise * torch.randn( | ||
287 | latents.shape[0], latents.shape[1], 1, 1, | 281 | latents.shape[0], latents.shape[1], 1, 1, |
288 | dtype=latents.dtype, | 282 | dtype=latents.dtype, |
289 | layout=latents.layout, | 283 | layout=latents.layout, |
290 | device=latents.device, | 284 | device=latents.device, |
291 | generator=generator | 285 | generator=generator |
292 | ) | 286 | ) |
293 | # noise += low_freq_noise * torch.randn( | ||
294 | # bsz, 1, 1, 1, | ||
295 | # dtype=latents.dtype, | ||
296 | # layout=latents.layout, | ||
297 | # device=latents.device, | ||
298 | # generator=generator | ||
299 | # ) | ||
300 | 287 | ||
301 | # Sample a random timestep for each image | 288 | # Sample a random timestep for each image |
302 | timesteps = torch.randint( | 289 | timesteps = torch.randint( |