Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  12
1 file changed, 2 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py
index 36269f0..1c38635 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -253,7 +253,6 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
-    noise_offset: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -268,17 +267,12 @@ def loss_step(
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

     # Sample noise that we'll add to the latents
-    offsets = noise_offset * torch.randn(
-        latents.shape[0], 1, 1, 1,
+    noise = torch.randn(
+        latents.shape,
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
         generator=generator
-    ).expand(latents.shape)
-    noise = torch.normal(
-        mean=offsets,
-        std=1,
-        generator=generator,
-    )
+    )

     # Sample a random timestep for each image
@@ -565,7 +559,6 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
-    noise_offset: float = 0.2,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -600,7 +593,6 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
-        noise_offset,
         seed,
     )

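For context: the removed code implemented the "offset noise" trick, in which each sample in the batch gets its noise drawn with a random per-image mean (shape [B, 1, 1, 1], scaled by noise_offset) instead of a zero mean. The commit drops this and keeps plain torch.randn. Below is a minimal, self-contained sketch of both behaviours side by side; the helper name sample_noise is illustrative and not part of this repository, but the tensor calls mirror the old and new code. The old torch.normal(mean=offsets, std=1) call is equivalent to torch.randn(...) + offsets, since adding a constant to N(0, 1) noise yields N(offset, 1), and that additive form is used here.

    from typing import Optional

    import torch

    def sample_noise(latents: torch.Tensor,
                     noise_offset: float = 0.0,
                     generator: Optional[torch.Generator] = None) -> torch.Tensor:
        # Plain Gaussian noise, matching the code this commit keeps.
        noise = torch.randn(
            latents.shape,
            dtype=latents.dtype,
            layout=latents.layout,
            device=latents.device,
            generator=generator,
        )
        if noise_offset > 0:
            # Removed "offset noise" path: shift each image's noise by a
            # random scalar mean, so the model also has to learn overall
            # brightness instead of only zero-mean perturbations.
            offsets = noise_offset * torch.randn(
                latents.shape[0], 1, 1, 1,
                dtype=latents.dtype,
                layout=latents.layout,
                device=latents.device,
                generator=generator,
            )
            noise = noise + offsets  # broadcasts [B,1,1,1] over [B,C,H,W]
        return noise

With noise_offset=0.2 (the old default removed from train()) this reproduces the pre-commit behaviour; with the default noise_offset=0.0 it reduces to the single torch.randn call the commit keeps.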