From 7c02c2fe68da2411623f0a11c1187ccf0f7743d8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 1 Dec 2022 22:01:47 +0100 Subject: Update --- dreambooth.py | 17 +++++++++-------- environment.yaml | 2 +- infer.py | 1 + textual_inversion.py | 9 +++++---- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index 31dbea2..1ead6dd 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -550,11 +550,11 @@ class Checkpointer: def main(): args = parse_args() - if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: - raise ValueError( - "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." - ) + # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + # raise ValueError( + # "Gradient accumulation is not supported when training the text encoder in distributed training. " + # "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + # ) instance_identifier = args.instance_identifier @@ -899,6 +899,9 @@ def main(): ) global_progress_bar.set_description("Total progress") + index_fixed_tokens = torch.arange(len(tokenizer)) + index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] + try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -910,7 +913,7 @@ def main(): sample_checkpoint = False for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(unet): + with accelerator.accumulate(itertools.chain(unet, text_encoder)): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -967,8 +970,6 @@ def main(): else: token_embeds = text_encoder.get_input_embeddings().weight - # Get the index for tokens that we want to freeze - index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] if accelerator.sync_gradients: diff --git a/environment.yaml b/environment.yaml index 4972ebd..24693d5 100644 --- a/environment.yaml +++ b/environment.yaml @@ -11,7 +11,7 @@ dependencies: - pytorch=1.12.1 - torchvision=0.13.1 - pandas=1.4.3 - - xformers=0.0.15.dev337 + - xformers=0.0.15.dev344 - pip: - -e . - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers diff --git a/infer.py b/infer.py index eabeb5e..75c8621 100644 --- a/infer.py +++ b/infer.py @@ -219,6 +219,7 @@ def create_pipeline(model, ti_embeddings_dir, dtype): scheduler=scheduler, ) pipeline.enable_xformers_memory_efficient_attention() + pipeline.enable_vae_slicing() pipeline.to("cuda") print("Pipeline loaded.") diff --git a/textual_inversion.py b/textual_inversion.py index d6be522..80f1d7d 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -545,6 +545,7 @@ def main(): checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder='scheduler') + vae.enable_slicing() unet.set_use_memory_efficient_attention_xformers(True) if args.gradient_checkpointing: @@ -814,6 +815,9 @@ def main(): ) global_progress_bar.set_description("Total progress") + index_fixed_tokens = torch.arange(len(tokenizer)) + index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] + try: for epoch in range(num_epochs): local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") @@ -827,7 +831,7 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -883,7 +887,6 @@ def main(): token_embeds = text_encoder.get_input_embeddings().weight # Get the index for tokens that we want to freeze - index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] optimizer.step() @@ -927,8 +930,6 @@ def main(): accelerator.wait_for_everyone() - print(token_embeds[placeholder_token_id]) - text_encoder.eval() val_loss = 0.0 -- cgit v1.2.3-54-g00ecf