diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py index 83e70e2..62b8260 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -157,6 +157,7 @@ def save_samples( | |||
157 | width=image_size, | 157 | width=image_size, |
158 | generator=gen, | 158 | generator=gen, |
159 | guidance_scale=guidance_scale, | 159 | guidance_scale=guidance_scale, |
160 | sag_scale=0, | ||
160 | num_inference_steps=num_steps, | 161 | num_inference_steps=num_steps, |
161 | output_type='pil' | 162 | output_type='pil' |
162 | ).images | 163 | ).images |
@@ -273,6 +274,12 @@ def loss_step( | |||
273 | layout=latents.layout, | 274 | layout=latents.layout, |
274 | device=latents.device, | 275 | device=latents.device, |
275 | generator=generator | 276 | generator=generator |
277 | ) + 0.1 * torch.randn( | ||
278 | latents.shape[0], latents.shape[1], 1, 1, | ||
279 | dtype=latents.dtype, | ||
280 | layout=latents.layout, | ||
281 | device=latents.device, | ||
282 | generator=generator | ||
276 | ) | 283 | ) |
277 | bsz = latents.shape[0] | 284 | bsz = latents.shape[0] |
278 | # Sample a random timestep for each image | 285 | # Sample a random timestep for each image |