diff options
| -rw-r--r-- | data/csv.py | 4 | ||||
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 13 | ||||
| -rw-r--r-- | train_lora.py | 8 | ||||
| -rw-r--r-- | train_ti.py | 45 | ||||
| -rw-r--r-- | training/functional.py | 33 |
5 files changed, 71 insertions, 32 deletions
diff --git a/data/csv.py b/data/csv.py index c5e7aef..81e8b6b 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -187,7 +187,7 @@ class VlpnDataModule(): | |||
| 187 | dropout: float = 0, | 187 | dropout: float = 0, |
| 188 | shuffle: bool = False, | 188 | shuffle: bool = False, |
| 189 | interpolation: str = "bicubic", | 189 | interpolation: str = "bicubic", |
| 190 | color_jitter: bool = True, | 190 | color_jitter: bool = False, |
| 191 | template_key: str = "template", | 191 | template_key: str = "template", |
| 192 | placeholder_tokens: list[str] = [], | 192 | placeholder_tokens: list[str] = [], |
| 193 | valid_set_size: Optional[int] = None, | 193 | valid_set_size: Optional[int] = None, |
| @@ -372,7 +372,7 @@ class VlpnDataset(IterableDataset): | |||
| 372 | dropout: float = 0, | 372 | dropout: float = 0, |
| 373 | shuffle: bool = False, | 373 | shuffle: bool = False, |
| 374 | interpolation: str = "bicubic", | 374 | interpolation: str = "bicubic", |
| 375 | color_jitter: bool = True, | 375 | color_jitter: bool = False, |
| 376 | generator: Optional[torch.Generator] = None, | 376 | generator: Optional[torch.Generator] = None, |
| 377 | npgenerator: Optional[np.random.Generator] = None, | 377 | npgenerator: Optional[np.random.Generator] = None, |
| 378 | ): | 378 | ): |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index aa3dbc6..aa446ec 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -386,7 +386,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 386 | 386 | ||
| 387 | def decode_latents(self, latents): | 387 | def decode_latents(self, latents): |
| 388 | latents = 1 / self.vae.config.scaling_factor * latents | 388 | latents = 1 / self.vae.config.scaling_factor * latents |
| 389 | image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | 389 | image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] |
| 390 | image = (image / 2 + 0.5).clamp(0, 1) | 390 | image = (image / 2 + 0.5).clamp(0, 1) |
| 391 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 391 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
| 392 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 392 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
| @@ -545,7 +545,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 545 | t, | 545 | t, |
| 546 | encoder_hidden_states=prompt_embeds, | 546 | encoder_hidden_states=prompt_embeds, |
| 547 | cross_attention_kwargs=cross_attention_kwargs, | 547 | cross_attention_kwargs=cross_attention_kwargs, |
| 548 | ).sample | 548 | return_dict=False, |
| 549 | )[0] | ||
| 549 | 550 | ||
| 550 | # perform guidance | 551 | # perform guidance |
| 551 | if do_classifier_free_guidance: | 552 | if do_classifier_free_guidance: |
| @@ -567,7 +568,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 567 | ) | 568 | ) |
| 568 | uncond_emb, _ = prompt_embeds.chunk(2) | 569 | uncond_emb, _ = prompt_embeds.chunk(2) |
| 569 | # forward and give guidance | 570 | # forward and give guidance |
| 570 | degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample | 571 | degraded_pred = self.unet( |
| 572 | degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0] | ||
| 571 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) | 573 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) |
| 572 | else: | 574 | else: |
| 573 | # DDIM-like prediction of x0 | 575 | # DDIM-like prediction of x0 |
| @@ -579,11 +581,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 579 | pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) | 581 | pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) |
| 580 | ) | 582 | ) |
| 581 | # forward and give guidance | 583 | # forward and give guidance |
| 582 | degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample | 584 | degraded_pred = self.unet( |
| 585 | degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0] | ||
| 583 | noise_pred += sag_scale * (noise_pred - degraded_pred) | 586 | noise_pred += sag_scale * (noise_pred - degraded_pred) |
| 584 | 587 | ||
| 585 | # compute the previous noisy sample x_t -> x_t-1 | 588 | # compute the previous noisy sample x_t -> x_t-1 |
| 586 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 589 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] |
| 587 | 590 | ||
| 588 | # call the callback, if provided | 591 | # call the callback, if provided |
| 589 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | 592 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
diff --git a/train_lora.py b/train_lora.py index 3c8fc97..cc7c1ec 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -251,6 +251,12 @@ def parse_args(): | |||
| 251 | help="Perlin offset noise strength.", | 251 | help="Perlin offset noise strength.", |
| 252 | ) | 252 | ) |
| 253 | parser.add_argument( | 253 | parser.add_argument( |
| 254 | "--input_pertubation", | ||
| 255 | type=float, | ||
| 256 | default=0, | ||
| 257 | help="The scale of input pretubation. Recommended 0.1." | ||
| 258 | ) | ||
| 259 | parser.add_argument( | ||
| 254 | "--num_train_epochs", | 260 | "--num_train_epochs", |
| 255 | type=int, | 261 | type=int, |
| 256 | default=None | 262 | default=None |
| @@ -1040,6 +1046,7 @@ def main(): | |||
| 1040 | checkpoint_output_dir=pti_checkpoint_output_dir, | 1046 | checkpoint_output_dir=pti_checkpoint_output_dir, |
| 1041 | sample_frequency=pti_sample_frequency, | 1047 | sample_frequency=pti_sample_frequency, |
| 1042 | offset_noise_strength=0, | 1048 | offset_noise_strength=0, |
| 1049 | input_pertubation=args.input_pertubation, | ||
| 1043 | no_val=True, | 1050 | no_val=True, |
| 1044 | ) | 1051 | ) |
| 1045 | 1052 | ||
| @@ -1195,6 +1202,7 @@ def main(): | |||
| 1195 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1202 | checkpoint_output_dir=lora_checkpoint_output_dir, |
| 1196 | sample_frequency=lora_sample_frequency, | 1203 | sample_frequency=lora_sample_frequency, |
| 1197 | offset_noise_strength=args.offset_noise_strength, | 1204 | offset_noise_strength=args.offset_noise_strength, |
| 1205 | input_pertubation=args.input_pertubation, | ||
| 1198 | no_val=args.valid_set_size == 0, | 1206 | no_val=args.valid_set_size == 0, |
| 1199 | avg_loss=avg_loss, | 1207 | avg_loss=avg_loss, |
| 1200 | avg_acc=avg_acc, | 1208 | avg_acc=avg_acc, |
diff --git a/train_ti.py b/train_ti.py index fce4a5e..ae73639 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -237,7 +237,13 @@ def parse_args(): | |||
| 237 | "--offset_noise_strength", | 237 | "--offset_noise_strength", |
| 238 | type=float, | 238 | type=float, |
| 239 | default=0, | 239 | default=0, |
| 240 | help="Perlin offset noise strength.", | 240 | help="Offset noise strength.", |
| 241 | ) | ||
| 242 | parser.add_argument( | ||
| 243 | "--input_pertubation", | ||
| 244 | type=float, | ||
| 245 | default=0, | ||
| 246 | help="The scale of input pretubation. Recommended 0.1." | ||
| 241 | ) | 247 | ) |
| 242 | parser.add_argument( | 248 | parser.add_argument( |
| 243 | "--num_train_epochs", | 249 | "--num_train_epochs", |
| @@ -407,6 +413,16 @@ def parse_args(): | |||
| 407 | ), | 413 | ), |
| 408 | ) | 414 | ) |
| 409 | parser.add_argument( | 415 | parser.add_argument( |
| 416 | "--compile_unet", | ||
| 417 | action="store_true", | ||
| 418 | help="Compile UNet with Torch Dynamo.", | ||
| 419 | ) | ||
| 420 | parser.add_argument( | ||
| 421 | "--use_xformers", | ||
| 422 | action="store_true", | ||
| 423 | help="Use xformers.", | ||
| 424 | ) | ||
| 425 | parser.add_argument( | ||
| 410 | "--checkpoint_frequency", | 426 | "--checkpoint_frequency", |
| 411 | type=int, | 427 | type=int, |
| 412 | default=999999, | 428 | default=999999, |
| @@ -671,23 +687,24 @@ def main(): | |||
| 671 | tokenizer.set_dropout(args.vector_dropout) | 687 | tokenizer.set_dropout(args.vector_dropout) |
| 672 | 688 | ||
| 673 | vae.enable_slicing() | 689 | vae.enable_slicing() |
| 674 | vae.set_use_memory_efficient_attention_xformers(True) | 690 | |
| 675 | unet.enable_xformers_memory_efficient_attention() | 691 | if args.use_xformers: |
| 676 | # unet = torch.compile(unet) | 692 | vae.set_use_memory_efficient_attention_xformers(True) |
| 693 | unet.enable_xformers_memory_efficient_attention() | ||
| 677 | 694 | ||
| 678 | if args.gradient_checkpointing: | 695 | if args.gradient_checkpointing: |
| 679 | unet.enable_gradient_checkpointing() | 696 | unet.enable_gradient_checkpointing() |
| 680 | text_encoder.gradient_checkpointing_enable() | 697 | text_encoder.gradient_checkpointing_enable() |
| 681 | 698 | ||
| 682 | convnext = create_model( | 699 | # convnext = create_model( |
| 683 | "convnext_tiny", | 700 | # "convnext_tiny", |
| 684 | pretrained=False, | 701 | # pretrained=False, |
| 685 | num_classes=3, | 702 | # num_classes=3, |
| 686 | drop_path_rate=0.0, | 703 | # drop_path_rate=0.0, |
| 687 | ) | 704 | # ) |
| 688 | convnext.to(accelerator.device, dtype=weight_dtype) | 705 | # convnext.to(accelerator.device, dtype=weight_dtype) |
| 689 | convnext.requires_grad_(False) | 706 | # convnext.requires_grad_(False) |
| 690 | convnext.eval() | 707 | # convnext.eval() |
| 691 | 708 | ||
| 692 | if len(args.alias_tokens) != 0: | 709 | if len(args.alias_tokens) != 0: |
| 693 | alias_placeholder_tokens = args.alias_tokens[::2] | 710 | alias_placeholder_tokens = args.alias_tokens[::2] |
| @@ -822,6 +839,7 @@ def main(): | |||
| 822 | noise_scheduler=noise_scheduler, | 839 | noise_scheduler=noise_scheduler, |
| 823 | dtype=weight_dtype, | 840 | dtype=weight_dtype, |
| 824 | seed=args.seed, | 841 | seed=args.seed, |
| 842 | compile_unet=args.compile_unet, | ||
| 825 | guidance_scale=args.guidance_scale, | 843 | guidance_scale=args.guidance_scale, |
| 826 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 844 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 827 | no_val=args.valid_set_size == 0, | 845 | no_val=args.valid_set_size == 0, |
| @@ -831,6 +849,7 @@ def main(): | |||
| 831 | milestone_checkpoints=not args.no_milestone_checkpoints, | 849 | milestone_checkpoints=not args.no_milestone_checkpoints, |
| 832 | global_step_offset=global_step_offset, | 850 | global_step_offset=global_step_offset, |
| 833 | offset_noise_strength=args.offset_noise_strength, | 851 | offset_noise_strength=args.offset_noise_strength, |
| 852 | input_pertubation=args.input_pertubation, | ||
| 834 | # -- | 853 | # -- |
| 835 | use_emb_decay=args.use_emb_decay, | 854 | use_emb_decay=args.use_emb_decay, |
| 836 | emb_decay_target=args.emb_decay_target, | 855 | emb_decay_target=args.emb_decay_target, |
diff --git a/training/functional.py b/training/functional.py index 38dd59f..e7e1eb3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -324,6 +324,7 @@ def loss_step( | |||
| 324 | prior_loss_weight: float, | 324 | prior_loss_weight: float, |
| 325 | seed: int, | 325 | seed: int, |
| 326 | offset_noise_strength: float, | 326 | offset_noise_strength: float, |
| 327 | input_pertubation: float, | ||
| 327 | disc: Optional[ConvNeXtDiscriminator], | 328 | disc: Optional[ConvNeXtDiscriminator], |
| 328 | min_snr_gamma: int, | 329 | min_snr_gamma: int, |
| 329 | step: int, | 330 | step: int, |
| @@ -337,7 +338,7 @@ def loss_step( | |||
| 337 | 338 | ||
| 338 | # Convert images to latent space | 339 | # Convert images to latent space |
| 339 | latents = vae.encode(images).latent_dist.sample(generator=generator) | 340 | latents = vae.encode(images).latent_dist.sample(generator=generator) |
| 340 | latents *= vae.config.scaling_factor | 341 | latents = latents * vae.config.scaling_factor |
| 341 | 342 | ||
| 342 | # Sample noise that we'll add to the latents | 343 | # Sample noise that we'll add to the latents |
| 343 | noise = torch.randn( | 344 | noise = torch.randn( |
| @@ -355,7 +356,10 @@ def loss_step( | |||
| 355 | device=latents.device, | 356 | device=latents.device, |
| 356 | generator=generator | 357 | generator=generator |
| 357 | ).expand(noise.shape) | 358 | ).expand(noise.shape) |
| 358 | noise += offset_noise_strength * offset_noise | 359 | noise = noise + offset_noise_strength * offset_noise |
| 360 | |||
| 361 | if input_pertubation != 0: | ||
| 362 | new_noise = noise + input_pertubation * torch.randn_like(noise) | ||
| 359 | 363 | ||
| 360 | # Sample a random timestep for each image | 364 | # Sample a random timestep for each image |
| 361 | timesteps = torch.randint( | 365 | timesteps = torch.randint( |
| @@ -369,7 +373,10 @@ def loss_step( | |||
| 369 | 373 | ||
| 370 | # Add noise to the latents according to the noise magnitude at each timestep | 374 | # Add noise to the latents according to the noise magnitude at each timestep |
| 371 | # (this is the forward diffusion process) | 375 | # (this is the forward diffusion process) |
| 372 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 376 | if input_pertubation != 0: |
| 377 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) | ||
| 378 | else: | ||
| 379 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 373 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 380 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
| 374 | 381 | ||
| 375 | # Get the text embedding for conditioning | 382 | # Get the text embedding for conditioning |
| @@ -381,7 +388,7 @@ def loss_step( | |||
| 381 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 388 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) |
| 382 | 389 | ||
| 383 | # Predict the noise residual | 390 | # Predict the noise residual |
| 384 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 391 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] |
| 385 | 392 | ||
| 386 | if guidance_scale != 0: | 393 | if guidance_scale != 0: |
| 387 | uncond_encoder_hidden_states = get_extended_embeddings( | 394 | uncond_encoder_hidden_states = get_extended_embeddings( |
| @@ -391,7 +398,7 @@ def loss_step( | |||
| 391 | ) | 398 | ) |
| 392 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | 399 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) |
| 393 | 400 | ||
| 394 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | 401 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] |
| 395 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 402 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) |
| 396 | 403 | ||
| 397 | # Get the target for loss depending on the prediction type | 404 | # Get the target for loss depending on the prediction type |
| @@ -424,9 +431,9 @@ def loss_step( | |||
| 424 | 431 | ||
| 425 | if disc is not None: | 432 | if disc is not None: |
| 426 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | 433 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) |
| 427 | rec_latent /= vae.config.scaling_factor | 434 | rec_latent = rec_latent / vae.config.scaling_factor |
| 428 | rec_latent = rec_latent.to(dtype=vae.dtype) | 435 | rec_latent = rec_latent.to(dtype=vae.dtype) |
| 429 | rec = vae.decode(rec_latent).sample | 436 | rec = vae.decode(rec_latent, return_dict=False)[0] |
| 430 | loss = 1 - disc.get_score(rec) | 437 | loss = 1 - disc.get_score(rec) |
| 431 | 438 | ||
| 432 | if min_snr_gamma != 0: | 439 | if min_snr_gamma != 0: |
| @@ -434,7 +441,7 @@ def loss_step( | |||
| 434 | mse_loss_weights = ( | 441 | mse_loss_weights = ( |
| 435 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | 442 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr |
| 436 | ) | 443 | ) |
| 437 | loss *= mse_loss_weights | 444 | loss = loss * mse_loss_weights |
| 438 | 445 | ||
| 439 | loss = loss.mean() | 446 | loss = loss.mean() |
| 440 | 447 | ||
| @@ -539,7 +546,7 @@ def train_loop( | |||
| 539 | with on_train(cycle): | 546 | with on_train(cycle): |
| 540 | for step, batch in enumerate(train_dataloader): | 547 | for step, batch in enumerate(train_dataloader): |
| 541 | loss, acc, bsz = loss_step(step, batch, cache) | 548 | loss, acc, bsz = loss_step(step, batch, cache) |
| 542 | loss /= gradient_accumulation_steps | 549 | loss = loss / gradient_accumulation_steps |
| 543 | 550 | ||
| 544 | accelerator.backward(loss) | 551 | accelerator.backward(loss) |
| 545 | 552 | ||
| @@ -598,7 +605,7 @@ def train_loop( | |||
| 598 | with torch.inference_mode(), on_eval(): | 605 | with torch.inference_mode(), on_eval(): |
| 599 | for step, batch in enumerate(val_dataloader): | 606 | for step, batch in enumerate(val_dataloader): |
| 600 | loss, acc, bsz = loss_step(step, batch, cache, True) | 607 | loss, acc, bsz = loss_step(step, batch, cache, True) |
| 601 | loss /= gradient_accumulation_steps | 608 | loss = loss / gradient_accumulation_steps |
| 602 | 609 | ||
| 603 | cur_loss_val.update(loss.item(), bsz) | 610 | cur_loss_val.update(loss.item(), bsz) |
| 604 | cur_acc_val.update(acc.item(), bsz) | 611 | cur_acc_val.update(acc.item(), bsz) |
| @@ -684,7 +691,8 @@ def train( | |||
| 684 | global_step_offset: int = 0, | 691 | global_step_offset: int = 0, |
| 685 | guidance_scale: float = 0.0, | 692 | guidance_scale: float = 0.0, |
| 686 | prior_loss_weight: float = 1.0, | 693 | prior_loss_weight: float = 1.0, |
| 687 | offset_noise_strength: float = 0.15, | 694 | offset_noise_strength: float = 0.01, |
| 695 | input_pertubation: float = 0.1, | ||
| 688 | disc: Optional[ConvNeXtDiscriminator] = None, | 696 | disc: Optional[ConvNeXtDiscriminator] = None, |
| 689 | min_snr_gamma: int = 5, | 697 | min_snr_gamma: int = 5, |
| 690 | avg_loss: AverageMeter = AverageMeter(), | 698 | avg_loss: AverageMeter = AverageMeter(), |
| @@ -704,7 +712,7 @@ def train( | |||
| 704 | 712 | ||
| 705 | if compile_unet: | 713 | if compile_unet: |
| 706 | unet = torch.compile(unet, backend='hidet') | 714 | unet = torch.compile(unet, backend='hidet') |
| 707 | # unet = torch.compile(unet) | 715 | # unet = torch.compile(unet, mode="reduce-overhead") |
| 708 | 716 | ||
| 709 | callbacks = strategy.callbacks( | 717 | callbacks = strategy.callbacks( |
| 710 | accelerator=accelerator, | 718 | accelerator=accelerator, |
| @@ -727,6 +735,7 @@ def train( | |||
| 727 | prior_loss_weight, | 735 | prior_loss_weight, |
| 728 | seed, | 736 | seed, |
| 729 | offset_noise_strength, | 737 | offset_noise_strength, |
| 738 | input_pertubation, | ||
| 730 | disc, | 739 | disc, |
| 731 | min_snr_gamma, | 740 | min_snr_gamma, |
| 732 | ) | 741 | ) |
