diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-01 14:10:23 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-01 14:10:23 +0100 |
| commit | 51394430b6b142eb21641f251b5fe32cdf802ab8 (patch) | |
| tree | d47033374f64a6c079571169aff57e8aff7b57d7 /training | |
| parent | Update (diff) | |
| download | textual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.tar.gz textual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.tar.bz2 textual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.zip | |
Changed low freq noise
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 33 |
1 files changed, 10 insertions, 23 deletions
diff --git a/training/functional.py b/training/functional.py index 990c4cd..8ea40bb 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -268,35 +268,22 @@ def loss_step( | |||
| 268 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | 268 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None |
| 269 | 269 | ||
| 270 | # Sample noise that we'll add to the latents | 270 | # Sample noise that we'll add to the latents |
| 271 | if low_freq_noise == 0: | 271 | noise = torch.randn( |
| 272 | noise = torch.randn( | 272 | latents.shape, |
| 273 | latents.shape, | 273 | dtype=latents.dtype, |
| 274 | dtype=latents.dtype, | 274 | layout=latents.layout, |
| 275 | layout=latents.layout, | 275 | device=latents.device, |
| 276 | device=latents.device, | 276 | generator=generator |
| 277 | generator=generator | 277 | ) |
| 278 | ) | 278 | |
| 279 | else: | 279 | if low_freq_noise != 0: |
| 280 | noise = (1 - low_freq_noise) * torch.randn( | 280 | noise *= 1 - low_freq_noise + low_freq_noise * torch.randn( |
| 281 | latents.shape, | ||
| 282 | dtype=latents.dtype, | ||
| 283 | layout=latents.layout, | ||
| 284 | device=latents.device, | ||
| 285 | generator=generator | ||
| 286 | ) + low_freq_noise * torch.randn( | ||
| 287 | latents.shape[0], latents.shape[1], 1, 1, | 281 | latents.shape[0], latents.shape[1], 1, 1, |
| 288 | dtype=latents.dtype, | 282 | dtype=latents.dtype, |
| 289 | layout=latents.layout, | 283 | layout=latents.layout, |
| 290 | device=latents.device, | 284 | device=latents.device, |
| 291 | generator=generator | 285 | generator=generator |
| 292 | ) | 286 | ) |
| 293 | # noise += low_freq_noise * torch.randn( | ||
| 294 | # bsz, 1, 1, 1, | ||
| 295 | # dtype=latents.dtype, | ||
| 296 | # layout=latents.layout, | ||
| 297 | # device=latents.device, | ||
| 298 | # generator=generator | ||
| 299 | # ) | ||
| 300 | 287 | ||
| 301 | # Sample a random timestep for each image | 288 | # Sample a random timestep for each image |
| 302 | timesteps = torch.randint( | 289 | timesteps = torch.randint( |
