From 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 31 Dec 2022 12:58:54 +0100
Subject: Added multi-vector embeddings

---
 train_ti.py | 88 ++++++++++++++++++++++++++++++-------------------------------
 1 file changed, 44 insertions(+), 44 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 088c1a6..69d15ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -16,17 +16,18 @@ from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler,
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_text_embeddings, load_text_embedding, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -80,6 +81,12 @@ def parse_args():
         nargs='*',
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -360,8 +367,17 @@ def parse_args():
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
 
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("You must specify --placeholder_token")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -386,8 +402,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        placeholder_token,
-        placeholder_token_id,
+        new_tokens,
         output_dir: Path,
         sample_image_size,
         sample_batches,
@@ -397,8 +412,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            placeholder_token=placeholder_token,
-            placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
             seed=seed or torch.random.seed(),
             sample_batches=sample_batches,
@@ -412,6 +425,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.new_tokens = new_tokens
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -422,13 +436,11 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        for new_token in self.new_tokens:
+            text_encoder.text_model.embeddings.save_embed(
+                new_token.multi_ids,
+                f"{slugify(new_token.token)}_{step}_{postfix}.bin"
+            )
 
         del text_encoder
         del learned_embeds
@@ -487,9 +499,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -507,45 +519,33 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    embeddings = patch_managed_embeddings(text_encoder)
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
         for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+    ]
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+        embeddings.add_embed(new_token.placeholder_id)
+        embeddings.add_embed(new_token.multi_ids, init_ids)
 
-    if args.resume_from is not None:
-        resumepath = Path(args.resume_from).joinpath("checkpoints")
-
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-        token_embeds[token_id] = embeddings
+    print(f"Added {len(new_tokens)} new tokens.")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
-    patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
@@ -575,7 +575,7 @@ def main():
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.trainable_embedding.parameters(),  # only optimize the embeddings
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -816,6 +816,7 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
+    config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
     if config["collection"] is not None:
         config["collection"] = " ".join(config["collection"])
     if config["exclude_collections"] is not None:
@@ -852,8 +853,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
+        new_tokens=new_tokens,
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-- 
cgit v1.2.3-54-g00ecf
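
For reference, a minimal sketch (not part of the patch) of how the multi-vector pieces introduced above fit together. MultiCLIPTokenizer.add_multi_tokens, patch_managed_embeddings, add_embed, save_embed, and the placeholder_id/multi_ids/token attributes are used exactly as in the diff; the model path, placeholder token, vector count, initializer word, and output filename are illustrative assumptions.

# Sketch only: mirrors the token/embedding setup and checkpoint-save paths shown in the patch.
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer

model_path = "path/to/model"  # placeholder, substitute a real Stable Diffusion checkpoint
tokenizer = MultiCLIPTokenizer.from_pretrained(model_path, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder='text_encoder')

# Swap the text encoder's token embeddings for the managed (trainable) variant.
embeddings = patch_managed_embeddings(text_encoder)

# One placeholder token backed by 4 vectors, initialized from the token ids of "cat".
init_ids = tokenizer.encode("cat", add_special_tokens=False)
new_tokens = tokenizer.add_multi_tokens(["<my-token>"], [4])

for new_token in new_tokens:
    embeddings.add_embed(new_token.placeholder_id)
    embeddings.add_embed(new_token.multi_ids, init_ids)

# Saving a learned embedding per token, as Checkpointer.checkpoint() does above
# (step 0 and the "end" postfix are placeholders).
text_encoder.text_model.embeddings.save_embed(
    new_tokens[0].multi_ids,
    "my-token_0_end.bin"
)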