summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-13 21:19:46 +0100
committerVolpeon <git@volpeon.ink>2023-02-13 21:19:46 +0100
commite9b7143c790ebc8b7b18c41f51d0e799ac84a337 (patch)
treeee9f6e3f192706ba96ec3aa20b24836b5ee0e673 /training
parentUpdate (diff)
downloadtextual-inversion-diff-e9b7143c790ebc8b7b18c41f51d0e799ac84a337.tar.gz
textual-inversion-diff-e9b7143c790ebc8b7b18c41f51d0e799ac84a337.tar.bz2
textual-inversion-diff-e9b7143c790ebc8b7b18c41f51d0e799ac84a337.zip
Better noise generation during training: https://www.crosslabs.org/blog/diffusion-with-offset-noise
Diffstat (limited to 'training')
-rw-r--r--training/functional.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py
index 83e70e2..62b8260 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -157,6 +157,7 @@ def save_samples(
157 width=image_size, 157 width=image_size,
158 generator=gen, 158 generator=gen,
159 guidance_scale=guidance_scale, 159 guidance_scale=guidance_scale,
160 sag_scale=0,
160 num_inference_steps=num_steps, 161 num_inference_steps=num_steps,
161 output_type='pil' 162 output_type='pil'
162 ).images 163 ).images
@@ -273,6 +274,12 @@ def loss_step(
273 layout=latents.layout, 274 layout=latents.layout,
274 device=latents.device, 275 device=latents.device,
275 generator=generator 276 generator=generator
277 ) + 0.1 * torch.randn(
278 latents.shape[0], latents.shape[1], 1, 1,
279 dtype=latents.dtype,
280 layout=latents.layout,
281 device=latents.device,
282 generator=generator
276 ) 283 )
277 bsz = latents.shape[0] 284 bsz = latents.shape[0]
278 # Sample a random timestep for each image 285 # Sample a random timestep for each image