From 0767c7bc82645186159965c2a6be4278e33c6721 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 23 Mar 2023 11:07:57 +0100 Subject: Update --- data/csv.py | 2 +- models/clip/util.py | 6 +++--- pipelines/stable_diffusion/vlpn_stable_diffusion.py | 4 +++- train_dreambooth.py | 11 +++++++---- train_lora.py | 3 --- train_ti.py | 2 +- training/functional.py | 7 ++----- training/strategy/dreambooth.py | 8 ++++---- training/strategy/lora.py | 12 ++++++------ training/strategy/ti.py | 4 ++-- 10 files changed, 29 insertions(+), 30 deletions(-) diff --git a/data/csv.py b/data/csv.py index 619452e..fba5d4b 100644 --- a/data/csv.py +++ b/data/csv.py @@ -282,7 +282,7 @@ class VlpnDataModule(): collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) if valid_set_size == 0: - data_train, data_val = items, items[:self.batch_size] + data_train, data_val = items, items else: data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) diff --git a/models/clip/util.py b/models/clip/util.py index 8de8c19..883de6a 100644 --- a/models/clip/util.py +++ b/models/clip/util.py @@ -23,11 +23,11 @@ def get_extended_embeddings( model_max_length = text_encoder.config.max_position_embeddings prompts = input_ids.shape[0] - input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) + input_ids = input_ids.view((-1, model_max_length)) if position_ids is not None: - position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) + position_ids = position_ids.view((-1, model_max_length)) if attention_mask is not None: - attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) + attention_mask = attention_mask.view((-1, model_max_length)) text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 4505a2a..dbd262f 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -291,7 +291,8 @@ class VlpnStableDiffusion(DiffusionPipeline): else: attention_mask = None - prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) + prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) return prompt_embeds @@ -374,6 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline): def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents + # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/train_dreambooth.py b/train_dreambooth.py index f8f6e84..a85ae4c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -582,12 +582,15 @@ def main(): ) datamodule.setup() - optimizer = create_optimizer( - itertools.chain( - unet.parameters(), + params_to_optimize = (unet.parameters(), ) + if args.train_text_encoder_epochs != 0: + params_to_optimize += ( text_encoder.text_model.encoder.parameters(), text_encoder.text_model.final_layer_norm.parameters(), - ), + ) + + optimizer = create_optimizer( + itertools.chain(*params_to_optimize), lr=args.learning_rate, ) diff --git a/train_lora.py b/train_lora.py index 787f271..8dd3c86 100644 --- a/train_lora.py +++ b/train_lora.py @@ -424,9 +424,6 @@ def main(): tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( args.pretrained_model_name_or_path) - tokenizer.set_use_vector_shuffle(args.vector_shuffle) - tokenizer.set_dropout(args.vector_dropout) - vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) unet.enable_xformers_memory_efficient_attention() diff --git a/train_ti.py b/train_ti.py index 7aeff7c..9bc74c1 100644 --- a/train_ti.py +++ b/train_ti.py @@ -340,7 +340,7 @@ def parse_args(): parser.add_argument( "--checkpoint_frequency", type=int, - default=5, + default=999999, help="How often to save a checkpoint and sample image (in epochs)", ) parser.add_argument( diff --git a/training/functional.py b/training/functional.py index ebb48ab..015fe5e 100644 --- a/training/functional.py +++ b/training/functional.py @@ -259,7 +259,7 @@ def snr_weight(noisy_latents, latents, gamma): sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) snr = torch.div(alpha_mean_sq, sigma_mean_sq) gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() + snr_weight = torch.fmin(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() return snr_weight return torch.tensor( @@ -471,10 +471,7 @@ def train_loop( "lr": lr_scheduler.get_last_lr()[0], } if isDadaptation: - logs["lr/d*lr"] = ( - optimizer.param_groups[0]["d"] * - optimizer.param_groups[0]["lr"] - ) + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] logs.update(on_log()) local_progress_bar.set_postfix(**logs) diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e5e84c8..28fccff 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -137,8 +137,8 @@ def dreambooth_strategy_callbacks( print("Saving model...") - unet_ = accelerator.unwrap_model(unet, False) - text_encoder_ = accelerator.unwrap_model(text_encoder, False) + unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) + text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) with ema_context(): pipeline = VlpnStableDiffusion( @@ -160,8 +160,8 @@ def dreambooth_strategy_callbacks( @torch.no_grad() def on_sample(step): with ema_context(): - unet_ = accelerator.unwrap_model(unet, False) - text_encoder_ = accelerator.unwrap_model(text_encoder, False) + unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) orig_unet_dtype = unet_.dtype orig_text_encoder_dtype = text_encoder_.dtype diff --git a/training/strategy/lora.py b/training/strategy/lora.py index aa75bec..1c8fad6 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -47,7 +47,6 @@ def lora_strategy_callbacks( save_samples_ = partial( save_samples, accelerator=accelerator, - text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, sample_scheduler=sample_scheduler, @@ -72,6 +71,7 @@ def lora_strategy_callbacks( @contextmanager def on_train(epoch: int): tokenizer.train() + text_encoder.train() yield @contextmanager @@ -89,8 +89,8 @@ def lora_strategy_callbacks( def on_checkpoint(step, postfix): print(f"Saving checkpoint for step {step}...") - unet_ = accelerator.unwrap_model(unet, False) - text_encoder_ = accelerator.unwrap_model(text_encoder, False) + unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) + text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) lora_config = {} state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) @@ -111,10 +111,10 @@ def lora_strategy_callbacks( @torch.no_grad() def on_sample(step): - unet_ = accelerator.unwrap_model(unet, False) - text_encoder_ = accelerator.unwrap_model(text_encoder, False) + unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) - save_samples_(step=step, unet=unet_) + save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) del unet_ del text_encoder_ diff --git a/training/strategy/ti.py b/training/strategy/ti.py index bd0d178..2038e34 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -156,8 +156,8 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_sample(step): with ema_context(): - unet_ = accelerator.unwrap_model(unet, False) - text_encoder_ = accelerator.unwrap_model(text_encoder, False) + unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) orig_unet_dtype = unet_.dtype orig_text_encoder_dtype = text_encoder_.dtype -- cgit v1.2.3-54-g00ecf