summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-01 14:10:23 +0100
committerVolpeon <git@volpeon.ink>2023-03-01 14:10:23 +0100
commit51394430b6b142eb21641f251b5fe32cdf802ab8 (patch)
treed47033374f64a6c079571169aff57e8aff7b57d7
parentUpdate (diff)
downloadtextual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.tar.gz
textual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.tar.bz2
textual-inversion-diff-51394430b6b142eb21641f251b5fe32cdf802ab8.zip
Changed low freq noise
-rw-r--r--training/functional.py33
1 files changed, 10 insertions, 23 deletions
diff --git a/training/functional.py b/training/functional.py
index 990c4cd..8ea40bb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -268,35 +268,22 @@ def loss_step(
268 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None 268 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
269 269
270 # Sample noise that we'll add to the latents 270 # Sample noise that we'll add to the latents
271 if low_freq_noise == 0: 271 noise = torch.randn(
272 noise = torch.randn( 272 latents.shape,
273 latents.shape, 273 dtype=latents.dtype,
274 dtype=latents.dtype, 274 layout=latents.layout,
275 layout=latents.layout, 275 device=latents.device,
276 device=latents.device, 276 generator=generator
277 generator=generator 277 )
278 ) 278
279 else: 279 if low_freq_noise != 0:
280 noise = (1 - low_freq_noise) * torch.randn( 280 noise *= 1 - low_freq_noise + low_freq_noise * torch.randn(
281 latents.shape,
282 dtype=latents.dtype,
283 layout=latents.layout,
284 device=latents.device,
285 generator=generator
286 ) + low_freq_noise * torch.randn(
287 latents.shape[0], latents.shape[1], 1, 1, 281 latents.shape[0], latents.shape[1], 1, 1,
288 dtype=latents.dtype, 282 dtype=latents.dtype,
289 layout=latents.layout, 283 layout=latents.layout,
290 device=latents.device, 284 device=latents.device,
291 generator=generator 285 generator=generator
292 ) 286 )
293 # noise += low_freq_noise * torch.randn(
294 # bsz, 1, 1, 1,
295 # dtype=latents.dtype,
296 # layout=latents.layout,
297 # device=latents.device,
298 # generator=generator
299 # )
300 287
301 # Sample a random timestep for each image 288 # Sample a random timestep for each image
302 timesteps = torch.randint( 289 timesteps = torch.randint(