From e09aaedd0e74f2fc6e2a53f233914803c65e127c Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Dec 2022 10:25:58 +0100 Subject: Training update --- train_dreambooth.py | 20 ++++---------------- train_ti.py | 10 ++++------ training/ti.py | 2 -- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index c7899a0..51e881a 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -24,6 +24,7 @@ from common import load_text_embeddings from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from training.optimization import get_one_cycle_schedule +from training.ti import patch_trainable_embeddings from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args from models.clip.prompt import PromptProcessor @@ -567,15 +568,8 @@ def main(): print(f"Training entire text encoder.") else: print(f"Training added text embeddings") - - freeze_params(itertools.chain( - text_encoder.text_model.encoder.parameters(), - text_encoder.text_model.final_layer_norm.parameters(), - text_encoder.text_model.embeddings.position_embedding.parameters(), - )) - - index_fixed_tokens = torch.arange(len(tokenizer)) - index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] + text_encoder.requires_grad_(False) + patch_trainable_embeddings(text_encoder, placeholder_token_id) prompt_processor = PromptProcessor(tokenizer, text_encoder) @@ -603,7 +597,7 @@ def main(): if args.train_text_encoder: text_encoder_params_to_optimize = text_encoder.parameters() else: - text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() + text_encoder_params_to_optimize = text_encoder.text_model.embeddings.trainable_embedding.parameters() # Initialize the optimizer optimizer = optimizer_class( @@ -914,12 +908,6 @@ def main(): ema_unet.step(unet) optimizer.zero_grad(set_to_none=True) - if not args.train_text_encoder: - # Let's make sure we don't update any embedding weights besides the newly added token - with torch.no_grad(): - text_encoder.get_input_embeddings( - ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] - avg_loss.update(loss.detach_(), bsz) avg_acc.update(acc.detach_(), bsz) diff --git a/train_ti.py b/train_ti.py index 52bd675..a12b889 100644 --- a/train_ti.py +++ b/train_ti.py @@ -368,7 +368,6 @@ class Checkpointer(CheckpointerBase): tokenizer, text_encoder, scheduler, - text_embeddings, placeholder_token, placeholder_token_id, output_dir: Path, @@ -394,7 +393,6 @@ class Checkpointer(CheckpointerBase): self.tokenizer = tokenizer self.text_encoder = text_encoder self.scheduler = scheduler - self.text_embeddings = text_embeddings @torch.no_grad() def checkpoint(self, step, postfix): @@ -407,7 +405,7 @@ class Checkpointer(CheckpointerBase): for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): # Save a checkpoint - learned_embeds = self.text_embeddings.trainable_embedding.weight[placeholder_token_id] + learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id] learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) @@ -517,8 +515,9 @@ def main(): vae.requires_grad_(False) unet.requires_grad_(False) + text_encoder.requires_grad_(False) - text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id) + patch_trainable_embeddings(text_encoder, placeholder_token_id) prompt_processor = PromptProcessor(tokenizer, text_encoder) @@ -541,7 +540,7 @@ def main(): # Initialize the optimizer optimizer = optimizer_class( - text_embeddings.trainable_embedding.parameters(), # only optimize the embeddings + text_encoder.text_model.embeddings.trainable_embedding.parameters(), # only optimize the embeddings lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -741,7 +740,6 @@ def main(): tokenizer=tokenizer, text_encoder=text_encoder, scheduler=checkpoint_scheduler, - text_embeddings=text_embeddings, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, output_dir=basepath, diff --git a/training/ti.py b/training/ti.py index a5e407b..8b2fdd6 100644 --- a/training/ti.py +++ b/training/ti.py @@ -18,8 +18,6 @@ def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]): text_encoder.text_model.embeddings = text_embeddings - return text_embeddings - class TrainableEmbeddings(CLIPTextEmbeddings): def __init__(self, config: CLIPTextConfig, new_ids: list[int]): -- cgit v1.2.3-70-g09d2