diff options
-rw-r--r-- | data/csv.py | 4 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 13 | ||||
-rw-r--r-- | train_lora.py | 8 | ||||
-rw-r--r-- | train_ti.py | 45 | ||||
-rw-r--r-- | training/functional.py | 33 |
5 files changed, 71 insertions, 32 deletions
diff --git a/data/csv.py b/data/csv.py index c5e7aef..81e8b6b 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -187,7 +187,7 @@ class VlpnDataModule(): | |||
187 | dropout: float = 0, | 187 | dropout: float = 0, |
188 | shuffle: bool = False, | 188 | shuffle: bool = False, |
189 | interpolation: str = "bicubic", | 189 | interpolation: str = "bicubic", |
190 | color_jitter: bool = True, | 190 | color_jitter: bool = False, |
191 | template_key: str = "template", | 191 | template_key: str = "template", |
192 | placeholder_tokens: list[str] = [], | 192 | placeholder_tokens: list[str] = [], |
193 | valid_set_size: Optional[int] = None, | 193 | valid_set_size: Optional[int] = None, |
@@ -372,7 +372,7 @@ class VlpnDataset(IterableDataset): | |||
372 | dropout: float = 0, | 372 | dropout: float = 0, |
373 | shuffle: bool = False, | 373 | shuffle: bool = False, |
374 | interpolation: str = "bicubic", | 374 | interpolation: str = "bicubic", |
375 | color_jitter: bool = True, | 375 | color_jitter: bool = False, |
376 | generator: Optional[torch.Generator] = None, | 376 | generator: Optional[torch.Generator] = None, |
377 | npgenerator: Optional[np.random.Generator] = None, | 377 | npgenerator: Optional[np.random.Generator] = None, |
378 | ): | 378 | ): |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index aa3dbc6..aa446ec 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -386,7 +386,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
386 | 386 | ||
387 | def decode_latents(self, latents): | 387 | def decode_latents(self, latents): |
388 | latents = 1 / self.vae.config.scaling_factor * latents | 388 | latents = 1 / self.vae.config.scaling_factor * latents |
389 | image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | 389 | image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] |
390 | image = (image / 2 + 0.5).clamp(0, 1) | 390 | image = (image / 2 + 0.5).clamp(0, 1) |
391 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 391 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
392 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 392 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
@@ -545,7 +545,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
545 | t, | 545 | t, |
546 | encoder_hidden_states=prompt_embeds, | 546 | encoder_hidden_states=prompt_embeds, |
547 | cross_attention_kwargs=cross_attention_kwargs, | 547 | cross_attention_kwargs=cross_attention_kwargs, |
548 | ).sample | 548 | return_dict=False, |
549 | )[0] | ||
549 | 550 | ||
550 | # perform guidance | 551 | # perform guidance |
551 | if do_classifier_free_guidance: | 552 | if do_classifier_free_guidance: |
@@ -567,7 +568,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
567 | ) | 568 | ) |
568 | uncond_emb, _ = prompt_embeds.chunk(2) | 569 | uncond_emb, _ = prompt_embeds.chunk(2) |
569 | # forward and give guidance | 570 | # forward and give guidance |
570 | degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample | 571 | degraded_pred = self.unet( |
572 | degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0] | ||
571 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) | 573 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) |
572 | else: | 574 | else: |
573 | # DDIM-like prediction of x0 | 575 | # DDIM-like prediction of x0 |
@@ -579,11 +581,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
579 | pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) | 581 | pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) |
580 | ) | 582 | ) |
581 | # forward and give guidance | 583 | # forward and give guidance |
582 | degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample | 584 | degraded_pred = self.unet( |
585 | degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0] | ||
583 | noise_pred += sag_scale * (noise_pred - degraded_pred) | 586 | noise_pred += sag_scale * (noise_pred - degraded_pred) |
584 | 587 | ||
585 | # compute the previous noisy sample x_t -> x_t-1 | 588 | # compute the previous noisy sample x_t -> x_t-1 |
586 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 589 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] |
587 | 590 | ||
588 | # call the callback, if provided | 591 | # call the callback, if provided |
589 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | 592 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
diff --git a/train_lora.py b/train_lora.py index 3c8fc97..cc7c1ec 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -251,6 +251,12 @@ def parse_args(): | |||
251 | help="Perlin offset noise strength.", | 251 | help="Perlin offset noise strength.", |
252 | ) | 252 | ) |
253 | parser.add_argument( | 253 | parser.add_argument( |
254 | "--input_pertubation", | ||
255 | type=float, | ||
256 | default=0, | ||
257 | help="The scale of input pretubation. Recommended 0.1." | ||
258 | ) | ||
259 | parser.add_argument( | ||
254 | "--num_train_epochs", | 260 | "--num_train_epochs", |
255 | type=int, | 261 | type=int, |
256 | default=None | 262 | default=None |
@@ -1040,6 +1046,7 @@ def main(): | |||
1040 | checkpoint_output_dir=pti_checkpoint_output_dir, | 1046 | checkpoint_output_dir=pti_checkpoint_output_dir, |
1041 | sample_frequency=pti_sample_frequency, | 1047 | sample_frequency=pti_sample_frequency, |
1042 | offset_noise_strength=0, | 1048 | offset_noise_strength=0, |
1049 | input_pertubation=args.input_pertubation, | ||
1043 | no_val=True, | 1050 | no_val=True, |
1044 | ) | 1051 | ) |
1045 | 1052 | ||
@@ -1195,6 +1202,7 @@ def main(): | |||
1195 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1202 | checkpoint_output_dir=lora_checkpoint_output_dir, |
1196 | sample_frequency=lora_sample_frequency, | 1203 | sample_frequency=lora_sample_frequency, |
1197 | offset_noise_strength=args.offset_noise_strength, | 1204 | offset_noise_strength=args.offset_noise_strength, |
1205 | input_pertubation=args.input_pertubation, | ||
1198 | no_val=args.valid_set_size == 0, | 1206 | no_val=args.valid_set_size == 0, |
1199 | avg_loss=avg_loss, | 1207 | avg_loss=avg_loss, |
1200 | avg_acc=avg_acc, | 1208 | avg_acc=avg_acc, |
diff --git a/train_ti.py b/train_ti.py index fce4a5e..ae73639 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -237,7 +237,13 @@ def parse_args(): | |||
237 | "--offset_noise_strength", | 237 | "--offset_noise_strength", |
238 | type=float, | 238 | type=float, |
239 | default=0, | 239 | default=0, |
240 | help="Perlin offset noise strength.", | 240 | help="Offset noise strength.", |
241 | ) | ||
242 | parser.add_argument( | ||
243 | "--input_pertubation", | ||
244 | type=float, | ||
245 | default=0, | ||
246 | help="The scale of input pretubation. Recommended 0.1." | ||
241 | ) | 247 | ) |
242 | parser.add_argument( | 248 | parser.add_argument( |
243 | "--num_train_epochs", | 249 | "--num_train_epochs", |
@@ -407,6 +413,16 @@ def parse_args(): | |||
407 | ), | 413 | ), |
408 | ) | 414 | ) |
409 | parser.add_argument( | 415 | parser.add_argument( |
416 | "--compile_unet", | ||
417 | action="store_true", | ||
418 | help="Compile UNet with Torch Dynamo.", | ||
419 | ) | ||
420 | parser.add_argument( | ||
421 | "--use_xformers", | ||
422 | action="store_true", | ||
423 | help="Use xformers.", | ||
424 | ) | ||
425 | parser.add_argument( | ||
410 | "--checkpoint_frequency", | 426 | "--checkpoint_frequency", |
411 | type=int, | 427 | type=int, |
412 | default=999999, | 428 | default=999999, |
@@ -671,23 +687,24 @@ def main(): | |||
671 | tokenizer.set_dropout(args.vector_dropout) | 687 | tokenizer.set_dropout(args.vector_dropout) |
672 | 688 | ||
673 | vae.enable_slicing() | 689 | vae.enable_slicing() |
674 | vae.set_use_memory_efficient_attention_xformers(True) | 690 | |
675 | unet.enable_xformers_memory_efficient_attention() | 691 | if args.use_xformers: |
676 | # unet = torch.compile(unet) | 692 | vae.set_use_memory_efficient_attention_xformers(True) |
693 | unet.enable_xformers_memory_efficient_attention() | ||
677 | 694 | ||
678 | if args.gradient_checkpointing: | 695 | if args.gradient_checkpointing: |
679 | unet.enable_gradient_checkpointing() | 696 | unet.enable_gradient_checkpointing() |
680 | text_encoder.gradient_checkpointing_enable() | 697 | text_encoder.gradient_checkpointing_enable() |
681 | 698 | ||
682 | convnext = create_model( | 699 | # convnext = create_model( |
683 | "convnext_tiny", | 700 | # "convnext_tiny", |
684 | pretrained=False, | 701 | # pretrained=False, |
685 | num_classes=3, | 702 | # num_classes=3, |
686 | drop_path_rate=0.0, | 703 | # drop_path_rate=0.0, |
687 | ) | 704 | # ) |
688 | convnext.to(accelerator.device, dtype=weight_dtype) | 705 | # convnext.to(accelerator.device, dtype=weight_dtype) |
689 | convnext.requires_grad_(False) | 706 | # convnext.requires_grad_(False) |
690 | convnext.eval() | 707 | # convnext.eval() |
691 | 708 | ||
692 | if len(args.alias_tokens) != 0: | 709 | if len(args.alias_tokens) != 0: |
693 | alias_placeholder_tokens = args.alias_tokens[::2] | 710 | alias_placeholder_tokens = args.alias_tokens[::2] |
@@ -822,6 +839,7 @@ def main(): | |||
822 | noise_scheduler=noise_scheduler, | 839 | noise_scheduler=noise_scheduler, |
823 | dtype=weight_dtype, | 840 | dtype=weight_dtype, |
824 | seed=args.seed, | 841 | seed=args.seed, |
842 | compile_unet=args.compile_unet, | ||
825 | guidance_scale=args.guidance_scale, | 843 | guidance_scale=args.guidance_scale, |
826 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 844 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
827 | no_val=args.valid_set_size == 0, | 845 | no_val=args.valid_set_size == 0, |
@@ -831,6 +849,7 @@ def main(): | |||
831 | milestone_checkpoints=not args.no_milestone_checkpoints, | 849 | milestone_checkpoints=not args.no_milestone_checkpoints, |
832 | global_step_offset=global_step_offset, | 850 | global_step_offset=global_step_offset, |
833 | offset_noise_strength=args.offset_noise_strength, | 851 | offset_noise_strength=args.offset_noise_strength, |
852 | input_pertubation=args.input_pertubation, | ||
834 | # -- | 853 | # -- |
835 | use_emb_decay=args.use_emb_decay, | 854 | use_emb_decay=args.use_emb_decay, |
836 | emb_decay_target=args.emb_decay_target, | 855 | emb_decay_target=args.emb_decay_target, |
diff --git a/training/functional.py b/training/functional.py index 38dd59f..e7e1eb3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -324,6 +324,7 @@ def loss_step( | |||
324 | prior_loss_weight: float, | 324 | prior_loss_weight: float, |
325 | seed: int, | 325 | seed: int, |
326 | offset_noise_strength: float, | 326 | offset_noise_strength: float, |
327 | input_pertubation: float, | ||
327 | disc: Optional[ConvNeXtDiscriminator], | 328 | disc: Optional[ConvNeXtDiscriminator], |
328 | min_snr_gamma: int, | 329 | min_snr_gamma: int, |
329 | step: int, | 330 | step: int, |
@@ -337,7 +338,7 @@ def loss_step( | |||
337 | 338 | ||
338 | # Convert images to latent space | 339 | # Convert images to latent space |
339 | latents = vae.encode(images).latent_dist.sample(generator=generator) | 340 | latents = vae.encode(images).latent_dist.sample(generator=generator) |
340 | latents *= vae.config.scaling_factor | 341 | latents = latents * vae.config.scaling_factor |
341 | 342 | ||
342 | # Sample noise that we'll add to the latents | 343 | # Sample noise that we'll add to the latents |
343 | noise = torch.randn( | 344 | noise = torch.randn( |
@@ -355,7 +356,10 @@ def loss_step( | |||
355 | device=latents.device, | 356 | device=latents.device, |
356 | generator=generator | 357 | generator=generator |
357 | ).expand(noise.shape) | 358 | ).expand(noise.shape) |
358 | noise += offset_noise_strength * offset_noise | 359 | noise = noise + offset_noise_strength * offset_noise |
360 | |||
361 | if input_pertubation != 0: | ||
362 | new_noise = noise + input_pertubation * torch.randn_like(noise) | ||
359 | 363 | ||
360 | # Sample a random timestep for each image | 364 | # Sample a random timestep for each image |
361 | timesteps = torch.randint( | 365 | timesteps = torch.randint( |
@@ -369,7 +373,10 @@ def loss_step( | |||
369 | 373 | ||
370 | # Add noise to the latents according to the noise magnitude at each timestep | 374 | # Add noise to the latents according to the noise magnitude at each timestep |
371 | # (this is the forward diffusion process) | 375 | # (this is the forward diffusion process) |
372 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 376 | if input_pertubation != 0: |
377 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) | ||
378 | else: | ||
379 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
373 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 380 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
374 | 381 | ||
375 | # Get the text embedding for conditioning | 382 | # Get the text embedding for conditioning |
@@ -381,7 +388,7 @@ def loss_step( | |||
381 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 388 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) |
382 | 389 | ||
383 | # Predict the noise residual | 390 | # Predict the noise residual |
384 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 391 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] |
385 | 392 | ||
386 | if guidance_scale != 0: | 393 | if guidance_scale != 0: |
387 | uncond_encoder_hidden_states = get_extended_embeddings( | 394 | uncond_encoder_hidden_states = get_extended_embeddings( |
@@ -391,7 +398,7 @@ def loss_step( | |||
391 | ) | 398 | ) |
392 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | 399 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) |
393 | 400 | ||
394 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | 401 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] |
395 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 402 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) |
396 | 403 | ||
397 | # Get the target for loss depending on the prediction type | 404 | # Get the target for loss depending on the prediction type |
@@ -424,9 +431,9 @@ def loss_step( | |||
424 | 431 | ||
425 | if disc is not None: | 432 | if disc is not None: |
426 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | 433 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) |
427 | rec_latent /= vae.config.scaling_factor | 434 | rec_latent = rec_latent / vae.config.scaling_factor |
428 | rec_latent = rec_latent.to(dtype=vae.dtype) | 435 | rec_latent = rec_latent.to(dtype=vae.dtype) |
429 | rec = vae.decode(rec_latent).sample | 436 | rec = vae.decode(rec_latent, return_dict=False)[0] |
430 | loss = 1 - disc.get_score(rec) | 437 | loss = 1 - disc.get_score(rec) |
431 | 438 | ||
432 | if min_snr_gamma != 0: | 439 | if min_snr_gamma != 0: |
@@ -434,7 +441,7 @@ def loss_step( | |||
434 | mse_loss_weights = ( | 441 | mse_loss_weights = ( |
435 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | 442 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr |
436 | ) | 443 | ) |
437 | loss *= mse_loss_weights | 444 | loss = loss * mse_loss_weights |
438 | 445 | ||
439 | loss = loss.mean() | 446 | loss = loss.mean() |
440 | 447 | ||
@@ -539,7 +546,7 @@ def train_loop( | |||
539 | with on_train(cycle): | 546 | with on_train(cycle): |
540 | for step, batch in enumerate(train_dataloader): | 547 | for step, batch in enumerate(train_dataloader): |
541 | loss, acc, bsz = loss_step(step, batch, cache) | 548 | loss, acc, bsz = loss_step(step, batch, cache) |
542 | loss /= gradient_accumulation_steps | 549 | loss = loss / gradient_accumulation_steps |
543 | 550 | ||
544 | accelerator.backward(loss) | 551 | accelerator.backward(loss) |
545 | 552 | ||
@@ -598,7 +605,7 @@ def train_loop( | |||
598 | with torch.inference_mode(), on_eval(): | 605 | with torch.inference_mode(), on_eval(): |
599 | for step, batch in enumerate(val_dataloader): | 606 | for step, batch in enumerate(val_dataloader): |
600 | loss, acc, bsz = loss_step(step, batch, cache, True) | 607 | loss, acc, bsz = loss_step(step, batch, cache, True) |
601 | loss /= gradient_accumulation_steps | 608 | loss = loss / gradient_accumulation_steps |
602 | 609 | ||
603 | cur_loss_val.update(loss.item(), bsz) | 610 | cur_loss_val.update(loss.item(), bsz) |
604 | cur_acc_val.update(acc.item(), bsz) | 611 | cur_acc_val.update(acc.item(), bsz) |
@@ -684,7 +691,8 @@ def train( | |||
684 | global_step_offset: int = 0, | 691 | global_step_offset: int = 0, |
685 | guidance_scale: float = 0.0, | 692 | guidance_scale: float = 0.0, |
686 | prior_loss_weight: float = 1.0, | 693 | prior_loss_weight: float = 1.0, |
687 | offset_noise_strength: float = 0.15, | 694 | offset_noise_strength: float = 0.01, |
695 | input_pertubation: float = 0.1, | ||
688 | disc: Optional[ConvNeXtDiscriminator] = None, | 696 | disc: Optional[ConvNeXtDiscriminator] = None, |
689 | min_snr_gamma: int = 5, | 697 | min_snr_gamma: int = 5, |
690 | avg_loss: AverageMeter = AverageMeter(), | 698 | avg_loss: AverageMeter = AverageMeter(), |
@@ -704,7 +712,7 @@ def train( | |||
704 | 712 | ||
705 | if compile_unet: | 713 | if compile_unet: |
706 | unet = torch.compile(unet, backend='hidet') | 714 | unet = torch.compile(unet, backend='hidet') |
707 | # unet = torch.compile(unet) | 715 | # unet = torch.compile(unet, mode="reduce-overhead") |
708 | 716 | ||
709 | callbacks = strategy.callbacks( | 717 | callbacks = strategy.callbacks( |
710 | accelerator=accelerator, | 718 | accelerator=accelerator, |
@@ -727,6 +735,7 @@ def train( | |||
727 | prior_loss_weight, | 735 | prior_loss_weight, |
728 | seed, | 736 | seed, |
729 | offset_noise_strength, | 737 | offset_noise_strength, |
738 | input_pertubation, | ||
730 | disc, | 739 | disc, |
731 | min_snr_gamma, | 740 | min_snr_gamma, |
732 | ) | 741 | ) |