From 8d2aa65402c829583e26cdf2c336b8d3057657d6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 5 May 2023 10:51:14 +0200
Subject: Add input perturbation, use return_dict=False, make xformers/compile
 opt-in

---
 data/csv.py                                        |  4 +-
 .../stable_diffusion/vlpn_stable_diffusion.py      | 13 ++++---
 train_lora.py                                      |  8 ++++
 train_ti.py                                        | 45 +++++++++++++++-------
 training/functional.py                             | 33 ++++++++++------
 5 files changed, 71 insertions(+), 32 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index c5e7aef..81e8b6b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -187,7 +187,7 @@ class VlpnDataModule():
         dropout: float = 0,
         shuffle: bool = False,
         interpolation: str = "bicubic",
-        color_jitter: bool = True,
+        color_jitter: bool = False,
         template_key: str = "template",
         placeholder_tokens: list[str] = [],
         valid_set_size: Optional[int] = None,
@@ -372,7 +372,7 @@ class VlpnDataset(IterableDataset):
         dropout: float = 0,
         shuffle: bool = False,
         interpolation: str = "bicubic",
-        color_jitter: bool = True,
+        color_jitter: bool = False,
         generator: Optional[torch.Generator] = None,
         npgenerator: Optional[np.random.Generator] = None,
     ):
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index aa3dbc6..aa446ec 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -386,7 +386,7 @@ class VlpnStableDiffusion(DiffusionPipeline):

     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample
+        image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -545,7 +545,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
@@ -567,7 +568,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
                         )
                         uncond_emb, _ = prompt_embeds.chunk(2)
                         # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                        degraded_pred = self.unet(
+                            degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0]
                         noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
                     else:
                         # DDIM-like prediction of x0
@@ -579,11 +581,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                             pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
                         )
                         # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                        degraded_pred = self.unet(
+                            degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0]
                         noise_pred += sag_scale * (noise_pred - degraded_pred)

                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/train_lora.py b/train_lora.py
index 3c8fc97..cc7c1ec 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -250,6 +250,12 @@ def parse_args():
         default=0,
         help="Perlin offset noise strength.",
     )
+    parser.add_argument(
+        "--input_perturbation",
+        type=float,
+        default=0,
+        help="The scale of input perturbation. Recommended 0.1."
+    )
     parser.add_argument(
         "--num_train_epochs",
         type=int,
@@ -1040,6 +1046,7 @@ def main():
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
         offset_noise_strength=0,
+        input_perturbation=args.input_perturbation,
         no_val=True,
     )

@@ -1195,6 +1202,7 @@ def main():
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
+        input_perturbation=args.input_perturbation,
         no_val=args.valid_set_size == 0,
         avg_loss=avg_loss,
         avg_acc=avg_acc,
diff --git a/train_ti.py b/train_ti.py
index fce4a5e..ae73639 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -237,7 +237,13 @@ def parse_args():
         "--offset_noise_strength",
         type=float,
         default=0,
-        help="Perlin offset noise strength.",
+        help="Offset noise strength.",
+    )
+    parser.add_argument(
+        "--input_perturbation",
+        type=float,
+        default=0,
+        help="The scale of input perturbation. Recommended 0.1."
     )
     parser.add_argument(
         "--num_train_epochs",
@@ -406,6 +412,16 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--compile_unet",
+        action="store_true",
+        help="Compile UNet with Torch Dynamo.",
+    )
+    parser.add_argument(
+        "--use_xformers",
+        action="store_true",
+        help="Use xformers.",
+    )
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
@@ -671,23 +687,24 @@ def main():
         tokenizer.set_dropout(args.vector_dropout)

     vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-    # unet = torch.compile(unet)
+
+    if args.use_xformers:
+        vae.set_use_memory_efficient_attention_xformers(True)
+        unet.enable_xformers_memory_efficient_attention()

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()

-    convnext = create_model(
-        "convnext_tiny",
-        pretrained=False,
-        num_classes=3,
-        drop_path_rate=0.0,
-    )
-    convnext.to(accelerator.device, dtype=weight_dtype)
-    convnext.requires_grad_(False)
-    convnext.eval()
+    # convnext = create_model(
+    #     "convnext_tiny",
+    #     pretrained=False,
+    #     num_classes=3,
+    #     drop_path_rate=0.0,
+    # )
+    # convnext.to(accelerator.device, dtype=weight_dtype)
+    # convnext.requires_grad_(False)
+    # convnext.eval()

     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
@@ -822,6 +839,7 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        compile_unet=args.compile_unet,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
@@ -831,6 +849,7 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
+        input_perturbation=args.input_perturbation,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
diff --git a/training/functional.py b/training/functional.py
index 38dd59f..e7e1eb3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -324,6 +324,7 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
+    input_perturbation: float,
     disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
     step: int,
@@ -337,7 +338,7 @@ def loss_step(

     # Convert images to latent space
     latents = vae.encode(images).latent_dist.sample(generator=generator)
-    latents *= vae.config.scaling_factor
+    latents = latents * vae.config.scaling_factor

     # Sample noise that we'll add to the latents
     noise = torch.randn(
@@ -355,7 +356,10 @@ def loss_step(
             device=latents.device,
             generator=generator
         ).expand(noise.shape)
-        noise += offset_noise_strength * offset_noise
+        noise = noise + offset_noise_strength * offset_noise
+
+    if input_perturbation != 0:
+        new_noise = noise + input_perturbation * torch.randn_like(noise)

     # Sample a random timestep for each image
     timesteps = torch.randint(
@@ -369,7 +373,10 @@ def loss_step(

     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
-    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+    if input_perturbation != 0:
+        noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+    else:
+        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
     noisy_latents = noisy_latents.to(dtype=unet.dtype)

     # Get the text embedding for conditioning
@@ -381,7 +388,7 @@ def loss_step(
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

     # Predict the noise residual
-    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

     if guidance_scale != 0:
         uncond_encoder_hidden_states = get_extended_embeddings(
@@ -391,7 +398,7 @@ def loss_step(
         )
         uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)

-        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
+        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
         model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)

     # Get the target for loss depending on the prediction type
@@ -424,9 +431,9 @@ def loss_step(

     if disc is not None:
         rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
-        rec_latent /= vae.config.scaling_factor
+        rec_latent = rec_latent / vae.config.scaling_factor
         rec_latent = rec_latent.to(dtype=vae.dtype)
-        rec = vae.decode(rec_latent).sample
+        rec = vae.decode(rec_latent, return_dict=False)[0]
         loss = 1 - disc.get_score(rec)

     if min_snr_gamma != 0:
@@ -434,7 +441,7 @@ def loss_step(
         mse_loss_weights = (
             torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
         )
-        loss *= mse_loss_weights
+        loss = loss * mse_loss_weights

     loss = loss.mean()

@@ -539,7 +546,7 @@ def train_loop(
             with on_train(cycle):
                 for step, batch in enumerate(train_dataloader):
                     loss, acc, bsz = loss_step(step, batch, cache)
-                    loss /= gradient_accumulation_steps
+                    loss = loss / gradient_accumulation_steps

                     accelerator.backward(loss)

@@ -598,7 +605,7 @@ def train_loop(
                 with torch.inference_mode(), on_eval():
                     for step, batch in enumerate(val_dataloader):
                         loss, acc, bsz = loss_step(step, batch, cache, True)
-                        loss /= gradient_accumulation_steps
+                        loss = loss / gradient_accumulation_steps

                         cur_loss_val.update(loss.item(), bsz)
                         cur_acc_val.update(acc.item(), bsz)
@@ -684,7 +691,8 @@ def train(
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
-    offset_noise_strength: float = 0.15,
+    offset_noise_strength: float = 0.01,
+    input_perturbation: float = 0.1,
     disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
     avg_loss: AverageMeter = AverageMeter(),
@@ -704,7 +712,7 @@ def train(

     if compile_unet:
         unet = torch.compile(unet, backend='hidet')
-        # unet = torch.compile(unet)
+        # unet = torch.compile(unet, mode="reduce-overhead")
mode="reduce-overhead") callbacks = strategy.callbacks( accelerator=accelerator, @@ -727,6 +735,7 @@ def train( prior_loss_weight, seed, offset_noise_strength, + input_pertubation, disc, min_snr_gamma, ) -- cgit v1.2.3-54-g00ecf