diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 7 | 
1 files changed, 7 insertions, 0 deletions
| diff --git a/training/functional.py b/training/functional.py index 83e70e2..62b8260 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -157,6 +157,7 @@ def save_samples( | |||
| 157 | width=image_size, | 157 | width=image_size, | 
| 158 | generator=gen, | 158 | generator=gen, | 
| 159 | guidance_scale=guidance_scale, | 159 | guidance_scale=guidance_scale, | 
| 160 | sag_scale=0, | ||
| 160 | num_inference_steps=num_steps, | 161 | num_inference_steps=num_steps, | 
| 161 | output_type='pil' | 162 | output_type='pil' | 
| 162 | ).images | 163 | ).images | 
| @@ -273,6 +274,12 @@ def loss_step( | |||
| 273 | layout=latents.layout, | 274 | layout=latents.layout, | 
| 274 | device=latents.device, | 275 | device=latents.device, | 
| 275 | generator=generator | 276 | generator=generator | 
| 277 | ) + 0.1 * torch.randn( | ||
| 278 | latents.shape[0], latents.shape[1], 1, 1, | ||
| 279 | dtype=latents.dtype, | ||
| 280 | layout=latents.layout, | ||
| 281 | device=latents.device, | ||
| 282 | generator=generator | ||
| 276 | ) | 283 | ) | 
| 277 | bsz = latents.shape[0] | 284 | bsz = latents.shape[0] | 
| 278 | # Sample a random timestep for each image | 285 | # Sample a random timestep for each image | 
