path: root/training/functional.py
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  31
1 files changed, 10 insertions, 21 deletions
diff --git a/training/functional.py b/training/functional.py
index 2d582bf..36269f0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -253,7 +253,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
-    low_freq_noise: float,
+    noise_offset: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -268,30 +268,19 @@ def loss_step(
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
 
     # Sample noise that we'll add to the latents
-    noise = torch.randn(
-        latents.shape,
+    offsets = noise_offset * torch.randn(
+        latents.shape[0], 1, 1, 1,
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
         generator=generator
+    ).expand(latents.shape)
+    noise = torch.normal(
+        mean=offsets,
+        std=1,
+        generator=generator,
     )
 
-    if low_freq_noise != 0:
-        low_freq_factor = low_freq_noise * torch.randn(
-            latents.shape[0], 1, 1, 1,
-            dtype=latents.dtype,
-            layout=latents.layout,
-            device=latents.device,
-            generator=generator
-        )
-        noise = noise * (1 - low_freq_factor) + low_freq_factor * torch.randn(
-            latents.shape[0], latents.shape[1], 1, 1,
-            dtype=latents.dtype,
-            layout=latents.layout,
-            device=latents.device,
-            generator=generator
-        )
-
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -576,7 +565,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
-    low_freq_noise: float = 0.1,
+    noise_offset: float = 0.2,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -611,7 +600,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
-        low_freq_noise,
+        noise_offset,
         seed,
     )
 
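For reference, below is a minimal standalone sketch of the noise-offset sampling this diff introduces, assuming latents is the usual (batch, channels, height, width) latent tensor; the helper name sample_offset_noise is hypothetical and not part of training/functional.py.

import torch
from typing import Optional

def sample_offset_noise(
    latents: torch.Tensor,
    noise_offset: float = 0.2,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    # One offset per sample, broadcast across channels and spatial dims,
    # so every element of an image shares the same mean shift.
    offsets = noise_offset * torch.randn(
        latents.shape[0], 1, 1, 1,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    ).expand(latents.shape)
    # Unit-variance noise centered on the per-sample offsets; equivalent
    # to adding offsets to standard normal noise of the same shape.
    return torch.normal(mean=offsets, std=1, generator=generator)

Compared with the removed low_freq_noise branch, which blended the pixel-wise Gaussian with extra per-image and per-channel noise, this simply shifts the mean of the noise by a small random per-image amount; the default strength in train() also moves from 0.1 to 0.2.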