summary refs log tree commit diff stats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r-- training/functional.py 36
1 file changed, 21 insertions, 15 deletions
diff --git a/training/functional.py b/training/functional.py
index 87bb339..d285366 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -274,7 +274,7 @@ def loss_step(
274 noise_scheduler: SchedulerMixin, 274 noise_scheduler: SchedulerMixin,
275 unet: UNet2DConditionModel, 275 unet: UNet2DConditionModel,
276 text_encoder: CLIPTextModel, 276 text_encoder: CLIPTextModel,
277 with_prior_preservation: bool, 277 guidance_scale: float,
278 prior_loss_weight: float, 278 prior_loss_weight: float,
279 seed: int, 279 seed: int,
280 offset_noise_strength: float, 280 offset_noise_strength: float,
@@ -283,13 +283,13 @@ def loss_step(
283 eval: bool = False, 283 eval: bool = False,
284 min_snr_gamma: int = 5, 284 min_snr_gamma: int = 5,
285): 285):
286 # Convert images to latent space 286 images = batch["pixel_values"]
287 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 287 generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
288 latents = latents * vae.config.scaling_factor 288 bsz = images.shape[0]
289
290 bsz = latents.shape[0]
291 289
292 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None 290 # Convert images to latent space
291 latents = vae.encode(images).latent_dist.sample(generator=generator)
292 latents *= vae.config.scaling_factor
293 293
294 # Sample noise that we'll add to the latents 294 # Sample noise that we'll add to the latents
295 noise = torch.randn( 295 noise = torch.randn(
@@ -301,13 +301,13 @@ def loss_step(
301 ) 301 )
302 302
303 if offset_noise_strength != 0: 303 if offset_noise_strength != 0:
304 noise += offset_noise_strength * perlin_noise( 304 offset_noise = torch.randn(
305 latents.shape, 305 (latents.shape[0], latents.shape[1], 1, 1),
306 res=1,
307 dtype=latents.dtype, 306 dtype=latents.dtype,
308 device=latents.device, 307 device=latents.device,
309 generator=generator 308 generator=generator
310 ) 309 ).expand(noise.shape)
310 noise += offset_noise_strength * offset_noise
311 311
312 # Sample a random timestep for each image 312 # Sample a random timestep for each image
313 timesteps = torch.randint( 313 timesteps = torch.randint(
@@ -343,7 +343,13 @@ def loss_step(
343 else: 343 else:
344 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 344 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
345 345
346 if with_prior_preservation: 346 if guidance_scale != 0:
347 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
348 model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0)
349 model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond)
350
351 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
352 elif prior_loss_weight != 0:
347 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 353 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
348 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 354 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
349 target, target_prior = torch.chunk(target, 2, dim=0) 355 target, target_prior = torch.chunk(target, 2, dim=0)
@@ -607,9 +613,9 @@ def train(
607 checkpoint_frequency: int = 50, 613 checkpoint_frequency: int = 50,
608 milestone_checkpoints: bool = True, 614 milestone_checkpoints: bool = True,
609 global_step_offset: int = 0, 615 global_step_offset: int = 0,
610 with_prior_preservation: bool = False, 616 guidance_scale: float = 0.0,
611 prior_loss_weight: float = 1.0, 617 prior_loss_weight: float = 1.0,
612 offset_noise_strength: float = 0.1, 618 offset_noise_strength: float = 0.15,
613 **kwargs, 619 **kwargs,
614): 620):
615 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 621 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -638,7 +644,7 @@ def train(
638 noise_scheduler, 644 noise_scheduler,
639 unet, 645 unet,
640 text_encoder, 646 text_encoder,
641 with_prior_preservation, 647 guidance_scale,
642 prior_loss_weight, 648 prior_loss_weight,
643 seed, 649 seed,
644 offset_noise_strength, 650 offset_noise_strength,