-rw-r--r--  data/csv.py                                           4
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  13
-rw-r--r--  train_lora.py                                          8
-rw-r--r--  train_ti.py                                           45
-rw-r--r--  training/functional.py                                33
5 files changed, 71 insertions, 32 deletions
diff --git a/data/csv.py b/data/csv.py
index c5e7aef..81e8b6b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -187,7 +187,7 @@ class VlpnDataModule():
         dropout: float = 0,
         shuffle: bool = False,
         interpolation: str = "bicubic",
-        color_jitter: bool = True,
+        color_jitter: bool = False,
         template_key: str = "template",
         placeholder_tokens: list[str] = [],
         valid_set_size: Optional[int] = None,
@@ -372,7 +372,7 @@ class VlpnDataset(IterableDataset):
         dropout: float = 0,
         shuffle: bool = False,
         interpolation: str = "bicubic",
-        color_jitter: bool = True,
+        color_jitter: bool = False,
         generator: Optional[torch.Generator] = None,
         npgenerator: Optional[np.random.Generator] = None,
     ):
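
For context on the default flip in both hunks above: color_jitter presumably gates a torchvision-style augmentation. The transform itself lives outside this diff, so the following is only an illustrative sketch of what such a flag typically toggles; the helper name and the jitter strengths are made up.

# Hypothetical sketch -- the real transform in data/csv.py is not shown in this diff.
from torchvision import transforms

def build_image_transforms(color_jitter: bool = False) -> transforms.Compose:
    ops = []
    if color_jitter:
        # small random shifts in brightness/contrast/saturation/hue
        ops.append(transforms.ColorJitter(brightness=0.1, contrast=0.1,
                                          saturation=0.1, hue=0.05))
    ops.append(transforms.ToTensor())
    return transforms.Compose(ops)
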
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index aa3dbc6..aa446ec 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -386,7 +386,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
+        image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -545,7 +545,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
@@ -567,7 +568,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
                         )
                         uncond_emb, _ = prompt_embeds.chunk(2)
                         # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                        degraded_pred = self.unet(
+                            degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0]
                         noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
                     else:
                         # DDIM-like prediction of x0
@@ -579,11 +581,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                             pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
                         )
                         # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                        degraded_pred = self.unet(
+                            degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
                         noise_pred += sag_scale * (noise_pred - degraded_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
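
Every change in this file follows one pattern: with return_dict=False, diffusers models and schedulers return a plain tuple instead of an output dataclass, so indexing [0] replaces the .sample / .prev_sample attribute access, which plays better with tracing and compilation. A minimal sketch of the equivalence, assuming a recent diffusers release; the toy UNet config is illustrative, not taken from this repo.

# Sketch: the dataclass access and the tuple access yield the same tensor.
import torch
from diffusers import UNet2DConditionModel

torch.manual_seed(0)
unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
).eval()

latents = torch.randn(1, 4, 8, 8)
t = torch.tensor([10])
emb = torch.randn(1, 77, 32)

with torch.no_grad():
    pred_dataclass = unet(latents, t, encoder_hidden_states=emb).sample
    pred_tuple = unet(latents, t, encoder_hidden_states=emb, return_dict=False)[0]

assert torch.allclose(pred_dataclass, pred_tuple)
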
diff --git a/train_lora.py b/train_lora.py
index 3c8fc97..cc7c1ec 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -251,6 +251,12 @@ def parse_args():
         help="Perlin offset noise strength.",
     )
     parser.add_argument(
+        "--input_pertubation",
+        type=float,
+        default=0,
+        help="The scale of input perturbation. Recommended 0.1.",
+    )
+    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=None
@@ -1040,6 +1046,7 @@ def main():
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
         offset_noise_strength=0,
+        input_pertubation=args.input_pertubation,
         no_val=True,
     )
 
@@ -1195,6 +1202,7 @@ def main():
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
+        input_pertubation=args.input_pertubation,
         no_val=args.valid_set_size == 0,
         avg_loss=avg_loss,
         avg_acc=avg_acc,
diff --git a/train_ti.py b/train_ti.py
index fce4a5e..ae73639 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -237,7 +237,13 @@ def parse_args():
237 "--offset_noise_strength", 237 "--offset_noise_strength",
238 type=float, 238 type=float,
239 default=0, 239 default=0,
240 help="Perlin offset noise strength.", 240 help="Offset noise strength.",
241 )
242 parser.add_argument(
243 "--input_pertubation",
244 type=float,
245 default=0,
246 help="The scale of input pretubation. Recommended 0.1."
241 ) 247 )
242 parser.add_argument( 248 parser.add_argument(
243 "--num_train_epochs", 249 "--num_train_epochs",
@@ -407,6 +413,16 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--compile_unet",
+        action="store_true",
+        help="Compile UNet with Torch Dynamo.",
+    )
+    parser.add_argument(
+        "--use_xformers",
+        action="store_true",
+        help="Use xformers.",
+    )
+    parser.add_argument(
         "--checkpoint_frequency",
         type=int,
         default=999999,
@@ -671,23 +687,24 @@ def main():
         tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-    # unet = torch.compile(unet)
+
+    if args.use_xformers:
+        vae.set_use_memory_efficient_attention_xformers(True)
+        unet.enable_xformers_memory_efficient_attention()
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    convnext = create_model(
-        "convnext_tiny",
-        pretrained=False,
-        num_classes=3,
-        drop_path_rate=0.0,
-    )
-    convnext.to(accelerator.device, dtype=weight_dtype)
-    convnext.requires_grad_(False)
-    convnext.eval()
+    # convnext = create_model(
+    #     "convnext_tiny",
+    #     pretrained=False,
+    #     num_classes=3,
+    #     drop_path_rate=0.0,
+    # )
+    # convnext.to(accelerator.device, dtype=weight_dtype)
+    # convnext.requires_grad_(False)
+    # convnext.eval()
 
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
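
The xformers hunk above turns two hard-coded calls into an opt-in path. A self-contained sketch of the same idea with an availability guard; the guard and the helper name are my additions, not part of the commit (is_xformers_available is from diffusers.utils).

# Sketch: opt-in xformers enabling; the import guard is not in the commit.
from diffusers.utils import is_xformers_available

def maybe_enable_xformers(vae, unet, use_xformers: bool) -> None:
    if use_xformers and is_xformers_available():
        # memory-efficient attention for both the VAE and the UNet
        vae.set_use_memory_efficient_attention_xformers(True)
        unet.enable_xformers_memory_efficient_attention()
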
@@ -822,6 +839,7 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        compile_unet=args.compile_unet,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
@@ -831,6 +849,7 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
+        input_pertubation=args.input_pertubation,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
diff --git a/training/functional.py b/training/functional.py
index 38dd59f..e7e1eb3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -324,6 +324,7 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
+    input_pertubation: float,
     disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
     step: int,
@@ -337,7 +338,7 @@ def loss_step(
 
     # Convert images to latent space
     latents = vae.encode(images).latent_dist.sample(generator=generator)
-    latents *= vae.config.scaling_factor
+    latents = latents * vae.config.scaling_factor
 
     # Sample noise that we'll add to the latents
     noise = torch.randn(
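
This hunk, like the /= and *= changes further down, swaps an in-place tensor op for an out-of-place one. Besides being friendlier to graph tracers such as torch.compile, in-place mutation can invalidate tensors autograd saved for the backward pass; a standalone sketch of that failure mode:

# Standalone repro of the autograd pitfall that in-place ops can trigger.
import torch

x = torch.randn(3, requires_grad=True)

y = x.exp()            # autograd saves the *output* y for exp()'s backward
try:
    y *= 2.0           # in-place: mutates the saved tensor
    y.sum().backward()
except RuntimeError as err:
    print("in-place version fails:", err)

y = x.exp()
y = y * 2.0            # out-of-place: the saved tensor stays intact
y.sum().backward()     # works
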
@@ -355,7 +356,10 @@ def loss_step(
             device=latents.device,
             generator=generator
         ).expand(noise.shape)
-        noise += offset_noise_strength * offset_noise
+        noise = noise + offset_noise_strength * offset_noise
+
+    if input_pertubation != 0:
+        new_noise = noise + input_pertubation * torch.randn_like(noise)
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
@@ -369,7 +373,10 @@ def loss_step(
 
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
-    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+    if input_pertubation != 0:
+        noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+    else:
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
     noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
     # Get the text embedding for conditioning
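
Together with the previous hunk, this implements input perturbation (Ning et al., "Input Perturbation Reduces Exposure Bias in Diffusion Models", arXiv:2301.11706): the forward-diffused UNet input is built from new_noise, while the training target stays the clean noise, so the model learns to denoise from slightly-off inputs like those it actually sees at sampling time. A standalone sketch of that asymmetry; the alphas and the fake model output are stand-ins, not the scheduler's real values.

# Sketch of the perturbed-input / clean-target asymmetry wired in above.
import torch
import torch.nn.functional as F

latents = torch.randn(4, 4, 64, 64)      # clean latents x_0
noise = torch.randn_like(latents)        # epsilon, the training target
input_pertubation = 0.1                  # spelling kept from the commit

new_noise = noise + input_pertubation * torch.randn_like(noise)

# forward diffusion q(x_t | x_0) built from the *perturbed* noise
alphas_cumprod = torch.rand(4, 1, 1, 1)  # stand-in for scheduler values
noisy_latents = (alphas_cumprod.sqrt() * latents
                 + (1 - alphas_cumprod).sqrt() * new_noise)

model_pred = torch.randn_like(noise)     # stand-in for unet(noisy_latents, ...)
loss = F.mse_loss(model_pred, noise)     # target is the clean noise, not new_noise
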
@@ -381,7 +388,7 @@ def loss_step(
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
 
     # Predict the noise residual
-    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
 
     if guidance_scale != 0:
         uncond_encoder_hidden_states = get_extended_embeddings(
@@ -391,7 +398,7 @@ def loss_step(
         )
         uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
 
-        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
+        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
         model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
 
     # Get the target for loss depending on the prediction type
@@ -424,9 +431,9 @@ def loss_step(
 
     if disc is not None:
         rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
-        rec_latent /= vae.config.scaling_factor
+        rec_latent = rec_latent / vae.config.scaling_factor
         rec_latent = rec_latent.to(dtype=vae.dtype)
-        rec = vae.decode(rec_latent).sample
+        rec = vae.decode(rec_latent, return_dict=False)[0]
         loss = 1 - disc.get_score(rec)
 
     if min_snr_gamma != 0:
@@ -434,7 +441,7 @@ def loss_step(
         mse_loss_weights = (
             torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
         )
-        loss *= mse_loss_weights
+        loss = loss * mse_loss_weights
 
     loss = loss.mean()
 
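
For reference, the weighting this hunk touches is the Min-SNR strategy (Hang et al., "Efficient Diffusion Training via Min-SNR Weighting Strategy", arXiv:2303.09556): each timestep's loss weight is clamped to min(SNR, gamma) / SNR, capping the influence of low-noise, high-SNR steps. A tiny numeric sketch:

# Numeric sketch of Min-SNR: high-SNR (low-noise) steps get down-weighted.
import torch

snr = torch.tensor([0.5, 1.0, 5.0, 25.0])
min_snr_gamma = 5.0
weights = torch.minimum(snr, torch.full_like(snr, min_snr_gamma)) / snr
print(weights)  # tensor([1.0000, 1.0000, 1.0000, 0.2000])
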
@@ -539,7 +546,7 @@ def train_loop(
         with on_train(cycle):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
-                loss /= gradient_accumulation_steps
+                loss = loss / gradient_accumulation_steps
 
                 accelerator.backward(loss)
 
@@ -598,7 +605,7 @@ def train_loop(
                 with torch.inference_mode(), on_eval():
                     for step, batch in enumerate(val_dataloader):
                         loss, acc, bsz = loss_step(step, batch, cache, True)
-                        loss /= gradient_accumulation_steps
+                        loss = loss / gradient_accumulation_steps
 
                         cur_loss_val.update(loss.item(), bsz)
                         cur_acc_val.update(acc.item(), bsz)
@@ -684,7 +691,8 @@ def train(
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
-    offset_noise_strength: float = 0.15,
+    offset_noise_strength: float = 0.01,
+    input_pertubation: float = 0.1,
     disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
     avg_loss: AverageMeter = AverageMeter(),
@@ -704,7 +712,7 @@ def train(
 
     if compile_unet:
         unet = torch.compile(unet, backend='hidet')
-        # unet = torch.compile(unet)
+        # unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
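
On the compile lines above: torch.compile accepts alternative TorchDynamo backends by name once the backing package is installed; 'hidet' comes from the hidet compiler (assumes pip install hidet), while mode="reduce-overhead" tunes the default inductor backend instead. A minimal sketch, assuming PyTorch >= 2.0:

# Sketch: compiling with the default backend; the commented variant mirrors
# the repo's hidet call and only works with the hidet package installed.
import torch

def norm_sq(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

compiled = torch.compile(norm_sq)                    # default inductor backend
# compiled = torch.compile(norm_sq, backend="hidet") # as in train() above
print(compiled(torch.randn(8)))
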
@@ -727,6 +735,7 @@ def train(
         prior_loss_weight,
         seed,
         offset_noise_strength,
+        input_pertubation,
         disc,
         min_snr_gamma,
     )