Diffstat (limited to 'training')

 -rw-r--r--  training/functional.py | 36 +++++++++++++++++++++---------------
 1 file changed, 21 insertions(+), 15 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 87bb339..d285366 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -274,7 +274,7 @@ def loss_step(
     noise_scheduler: SchedulerMixin,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior_preservation: bool,
+    guidance_scale: float,
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
@@ -283,13 +283,13 @@ def loss_step(
     eval: bool = False,
     min_snr_gamma: int = 5,
 ):
-    # Convert images to latent space
-    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-    latents = latents * vae.config.scaling_factor
-
-    bsz = latents.shape[0]
+    images = batch["pixel_values"]
+    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    bsz = images.shape[0]
 
-    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+    # Convert images to latent space
+    latents = vae.encode(images).latent_dist.sample(generator=generator)
+    latents *= vae.config.scaling_factor
 
     # Sample noise that we'll add to the latents
     noise = torch.randn(
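Note on the reordering in this hunk: the torch.Generator is now created before the VAE encode, so during eval the posterior sample itself is seeded and repeatable. A minimal standalone sketch of the idea, assuming a diffusers-style AutoencoderKL whose latent_dist.sample accepts a generator:

    import torch

    def encode_deterministic(vae, images, seed, step):
        # Seed one generator per eval step so the same (seed, step) pair
        # always draws the same latent sample from the VAE posterior.
        generator = torch.Generator(device=images.device).manual_seed(seed + step)
        latents = vae.encode(images).latent_dist.sample(generator=generator)
        return latents * vae.config.scaling_factor

During training (eval is False) the generator stays None and the draw remains stochastic, matching the patched code.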
@@ -301,13 +301,13 @@
     )
 
     if offset_noise_strength != 0:
-        noise += offset_noise_strength * perlin_noise(
-            latents.shape,
-            res=1,
-            dtype=latents.dtype,
-            device=latents.device,
-            generator=generator
-        )
+        offset_noise = torch.randn(
+            (latents.shape[0], latents.shape[1], 1, 1),
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        ).expand(noise.shape)
+        noise += offset_noise_strength * offset_noise
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
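This hunk replaces Perlin noise with the simpler offset-noise trick: a single Gaussian draw per (sample, channel), broadcast across the spatial dimensions, which shifts whole channels brighter or darker and helps the model reproduce very dark or very bright images. A self-contained sketch of the same computation (the helper name is illustrative):

    import torch

    def add_offset_noise(noise, latents, strength, generator=None):
        # One scalar per (batch, channel), expanded over H and W, so each
        # channel of each sample is offset as a whole.
        offset = torch.randn(
            (latents.shape[0], latents.shape[1], 1, 1),
            dtype=latents.dtype,
            device=latents.device,
            generator=generator,
        ).expand(noise.shape)
        return noise + strength * offset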
@@ -343,7 +343,13 @@
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if with_prior_preservation:
+    if guidance_scale != 0:
+        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+        model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0)
+        model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond)
+
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+    elif prior_loss_weight != 0:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
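The new guidance_scale branch applies the classifier-free guidance formula to the model prediction before computing the loss; it assumes the batch stacks the unconditional half before the text-conditioned half along dim 0. A sketch of just that combination (the helper name is hypothetical, and target is assumed to already have the combined prediction's shape):

    import torch
    import torch.nn.functional as F

    def cfg_loss(model_pred, target, guidance_scale):
        # Standard CFG combination: move the conditioned prediction away
        # from the unconditional one by a factor of guidance_scale.
        pred_uncond, pred_text = torch.chunk(model_pred, 2, dim=0)
        guided = pred_uncond + guidance_scale * (pred_text - pred_uncond)
        # Assumption: target matches guided's (halved) batch size.
        return F.mse_loss(guided.float(), target.float(), reduction="none")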
@@ -607,9 +613,9 @@ def train(
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
     global_step_offset: int = 0,
-    with_prior_preservation: bool = False,
+    guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
-    offset_noise_strength: float = 0.1,
+    offset_noise_strength: float = 0.15,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -638,7 +644,7 @@
         noise_scheduler,
         unet,
         text_encoder,
-        with_prior_preservation,
+        guidance_scale,
         prior_loss_weight,
         seed,
         offset_noise_strength,
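For callers of train, the migration is mechanical: the with_prior_preservation boolean becomes two float knobs, and the loss branches are mutually exclusive (a non-zero guidance_scale takes precedence; otherwise a non-zero prior_loss_weight selects the prior-preservation path). A hypothetical call site, with placeholder values and unrelated arguments elided:

    train(
        strategy=strategy,
        guidance_scale=0.0,          # 0.0 skips the CFG branch
        prior_loss_weight=1.0,       # only consulted when guidance_scale == 0
        offset_noise_strength=0.15,  # default raised from 0.1
        # ...remaining training arguments...
    )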