diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 83 |
1 files changed, 59 insertions, 24 deletions
diff --git a/training/functional.py b/training/functional.py index 2f7f837..be39776 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -23,6 +23,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
| 23 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 23 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
| 24 | from models.clip.util import get_extended_embeddings | 24 | from models.clip.util import get_extended_embeddings |
| 25 | from models.clip.tokenizer import MultiCLIPTokenizer | 25 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 26 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
| 26 | from training.util import AverageMeter | 27 | from training.util import AverageMeter |
| 27 | from util.slerp import slerp | 28 | from util.slerp import slerp |
| 28 | 29 | ||
| @@ -160,7 +161,8 @@ def save_samples( | |||
| 160 | 161 | ||
| 161 | for tracker in accelerator.trackers: | 162 | for tracker in accelerator.trackers: |
| 162 | if tracker.name == "tensorboard": | 163 | if tracker.name == "tensorboard": |
| 163 | tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") | 164 | # tracker.writer.add_images(pool, all_samples, step, dataformats="NHWC") |
| 165 | pass | ||
| 164 | 166 | ||
| 165 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) | 167 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) |
| 166 | image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] | 168 | image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] |
| @@ -275,16 +277,38 @@ def compute_snr(timesteps, noise_scheduler): | |||
| 275 | return snr | 277 | return snr |
| 276 | 278 | ||
| 277 | 279 | ||
| 278 | def make_solid_image(color: float, shape, vae, dtype, device, generator): | 280 | def get_original( |
| 279 | img = torch.tensor( | 281 | noise_scheduler, |
| 280 | [[[[color]]]], | 282 | model_output, |
| 281 | dtype=dtype, | 283 | sample: torch.FloatTensor, |
| 282 | device=device | 284 | timesteps: torch.IntTensor |
| 283 | ).expand(1, *shape) | 285 | ): |
| 284 | img = img * 2 - 1 | 286 | alphas_cumprod = noise_scheduler.alphas_cumprod |
| 285 | img = vae.encode(img).latent_dist.sample(generator=generator) | 287 | sqrt_alphas_cumprod = alphas_cumprod**0.5 |
| 286 | img *= vae.config.scaling_factor | 288 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 |
| 287 | return img | 289 | |
| 290 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | ||
| 291 | while len(sqrt_alphas_cumprod.shape) < len(sample.shape): | ||
| 292 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | ||
| 293 | alpha = sqrt_alphas_cumprod.expand(sample.shape) | ||
| 294 | |||
| 295 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | ||
| 296 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): | ||
| 297 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | ||
| 298 | sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) | ||
| 299 | |||
| 300 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 301 | pred_original_sample = (sample - sigma * model_output) / alpha | ||
| 302 | elif noise_scheduler.config.prediction_type == "sample": | ||
| 303 | pred_original_sample = model_output | ||
| 304 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 305 | pred_original_sample = alpha * sample - sigma * model_output | ||
| 306 | else: | ||
| 307 | raise ValueError( | ||
| 308 | f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" | ||
| 309 | " `v_prediction` for the DDPMScheduler." | ||
| 310 | ) | ||
| 311 | return pred_original_sample | ||
| 288 | 312 | ||
| 289 | 313 | ||
| 290 | def loss_step( | 314 | def loss_step( |
| @@ -296,6 +320,7 @@ def loss_step( | |||
| 296 | prior_loss_weight: float, | 320 | prior_loss_weight: float, |
| 297 | seed: int, | 321 | seed: int, |
| 298 | offset_noise_strength: float, | 322 | offset_noise_strength: float, |
| 323 | disc: Optional[ConvNeXtDiscriminator], | ||
| 299 | min_snr_gamma: int, | 324 | min_snr_gamma: int, |
| 300 | step: int, | 325 | step: int, |
| 301 | batch: dict[str, Any], | 326 | batch: dict[str, Any], |
| @@ -373,23 +398,31 @@ def loss_step( | |||
| 373 | else: | 398 | else: |
| 374 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 399 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 375 | 400 | ||
| 376 | if guidance_scale == 0 and prior_loss_weight != 0: | 401 | if disc is None: |
| 377 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 402 | if guidance_scale == 0 and prior_loss_weight != 0: |
| 378 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 403 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 379 | target, target_prior = torch.chunk(target, 2, dim=0) | 404 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 405 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
| 380 | 406 | ||
| 381 | # Compute instance loss | 407 | # Compute instance loss |
| 382 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 408 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
| 383 | 409 | ||
| 384 | # Compute prior loss | 410 | # Compute prior loss |
| 385 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") | 411 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") |
| 386 | 412 | ||
| 387 | # Add the prior loss to the instance loss. | 413 | # Add the prior loss to the instance loss. |
| 388 | loss = loss + prior_loss_weight * prior_loss | 414 | loss = loss + prior_loss_weight * prior_loss |
| 389 | else: | 415 | else: |
| 390 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 416 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
| 391 | 417 | ||
| 392 | loss = loss.mean([1, 2, 3]) | 418 | loss = loss.mean([1, 2, 3]) |
| 419 | else: | ||
| 420 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | ||
| 421 | rec_latent /= vae.config.scaling_factor | ||
| 422 | rec_latent = rec_latent.to(dtype=vae.dtype) | ||
| 423 | rec = vae.decode(rec_latent).sample | ||
| 424 | loss = 1 - disc.get_score(rec) | ||
| 425 | del rec_latent, rec | ||
| 393 | 426 | ||
| 394 | if min_snr_gamma != 0: | 427 | if min_snr_gamma != 0: |
| 395 | snr = compute_snr(timesteps, noise_scheduler) | 428 | snr = compute_snr(timesteps, noise_scheduler) |
| @@ -645,6 +678,7 @@ def train( | |||
| 645 | guidance_scale: float = 0.0, | 678 | guidance_scale: float = 0.0, |
| 646 | prior_loss_weight: float = 1.0, | 679 | prior_loss_weight: float = 1.0, |
| 647 | offset_noise_strength: float = 0.15, | 680 | offset_noise_strength: float = 0.15, |
| 681 | disc: Optional[ConvNeXtDiscriminator] = None, | ||
| 648 | min_snr_gamma: int = 5, | 682 | min_snr_gamma: int = 5, |
| 649 | **kwargs, | 683 | **kwargs, |
| 650 | ): | 684 | ): |
| @@ -676,6 +710,7 @@ def train( | |||
| 676 | prior_loss_weight, | 710 | prior_loss_weight, |
| 677 | seed, | 711 | seed, |
| 678 | offset_noise_strength, | 712 | offset_noise_strength, |
| 713 | disc, | ||
| 679 | min_snr_gamma, | 714 | min_snr_gamma, |
| 680 | ) | 715 | ) |
| 681 | 716 | ||
