From 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 31 Dec 2022 12:58:54 +0100 Subject: Added multi-vector embeddings --- common.py | 38 ++++++---------- environment.yaml | 1 + infer.py | 14 +++--- models/clip/embeddings.py | 109 ++++++++++++++++++++++++++++++++++++++++++++++ models/clip/prompt.py | 6 +-- models/clip/tokenizer.py | 64 +++++++++++++++++++++++++++ train_ti.py | 88 ++++++++++++++++++------------------- training/util.py | 60 +++++++++++++++++++++++-- 8 files changed, 299 insertions(+), 81 deletions(-) create mode 100644 models/clip/embeddings.py create mode 100644 models/clip/tokenizer.py diff --git a/common.py b/common.py index f369475..e8d3ac1 100644 --- a/common.py +++ b/common.py @@ -1,9 +1,10 @@ from pathlib import Path import json -import torch +from models.clip.embeddings import ManagedCLIPTextEmbeddings +from models.clip.tokenizer import MultiCLIPTokenizer -from transformers import CLIPTextModel, CLIPTokenizer +from safetensors import safe_open def load_config(filename): @@ -18,33 +19,20 @@ def load_config(filename): return args -def load_text_embedding(embeddings, token_id, file): - data = torch.load(file, map_location="cpu") - - assert len(data.keys()) == 1, 'embedding data has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - - embeddings[token_id] = emb - - -def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path): +def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path): if not embeddings_dir.exists() or not embeddings_dir.is_dir(): return [] - files = [file for file in embeddings_dir.iterdir() if file.is_file()] - - tokens = [file.stem for file in files] - added = tokenizer.add_tokens(tokens) - token_ids = tokenizer.convert_tokens_to_ids(tokens) - - text_encoder.resize_token_embeddings(len(tokenizer)) + filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] + tokens = [filename.stem for filename in filenames] - token_embeds = text_encoder.get_input_embeddings().weight.data + for filename in embeddings_dir.iterdir(): + if filename.is_file(): + with safe_open(filename, framework="pt", device="cpu") as file: + embed = file.get_tensor("embed") - for (token_id, file) in zip(token_ids, files): - load_text_embedding(token_embeds, token_id, file) + added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0]) + embeddings.add_embed(added.placeholder_id) + embeddings.add_embed(added.multi_ids, embed) return tokens diff --git a/environment.yaml b/environment.yaml index c006379..7f0e903 100644 --- a/environment.yaml +++ b/environment.yaml @@ -18,6 +18,7 @@ dependencies: - accelerate==0.15.0 - bitsandbytes==0.35.4 - python-slugify>=6.1.2 + - safetensors==0.2.7 - setuptools==65.6.3 - test-tube>=0.7.5 - transformers==4.25.1 diff --git a/infer.py b/infer.py index ae0b4da..4bcaff5 100644 --- a/infer.py +++ b/infer.py @@ -8,6 +8,7 @@ from pathlib import Path import torch import json from PIL import Image +from slugify import slugify from diffusers import ( AutoencoderKL, UNet2DConditionModel, @@ -20,11 +21,12 @@ from diffusers import ( KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler ) -from transformers import CLIPTextModel, CLIPTokenizer -from slugify import slugify +from transformers import CLIPTextModel +from models.clip.embeddings import patch_managed_embeddings +from models.clip.tokenizer import MultiCLIPTokenizer from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from common import load_text_embeddings, load_config +from common import load_config, load_embeddings_from_dir torch.backends.cuda.matmul.allow_tf32 = True @@ -183,13 +185,15 @@ def save_args(basepath, args, extra={}): def create_pipeline(model, embeddings_dir, dtype): print("Loading Stable Diffusion pipeline...") - tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) + tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) - added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir)) + embeddings = patch_managed_embeddings(text_encoder) + added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir)) + print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") pipeline = VlpnStableDiffusion( diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py new file mode 100644 index 0000000..7d63ffb --- /dev/null +++ b/models/clip/embeddings.py @@ -0,0 +1,109 @@ +from typing import Union, Optional +from pathlib import Path + +import torch +import torch.nn as nn + +from safetensors import safe_open +from safetensors.torch import save_file + +from transformers import CLIPTextModel +from transformers.models.clip import CLIPTextConfig +from transformers.models.clip.modeling_clip import CLIPTextEmbeddings + + +def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding: + old_num_embeddings, old_embedding_dim = old_embedding.weight.size() + + new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim) + new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype) + new_embedding.weight.data.zero_() + new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data + + return new_embedding + + +class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): + def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings): + super().__init__(config) + + self.token_embedding = embeddings.token_embedding + self.position_embedding = embeddings.position_embedding + + self.temp_token_embedding = nn.Embedding( + self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) + self.temp_token_embedding.weight.data.zero_() + self.temp_token_ids = torch.tensor([]) + + def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None): + if isinstance(token_ids, int): + token_ids = [token_ids] + + if initializer is not None: + if isinstance(initializer, int): + initializer = [initializer] + + if isinstance(initializer, list): + initializer = (initializer * len(token_ids))[:len(token_ids)] + + with torch.no_grad(): + initializer = self.get_embed(initializer) + + self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids)) + self.token_embedding = expand_embedding(self.token_embedding, len(token_ids)) + + token_ids = torch.tensor(token_ids) + + self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) + + if initializer is not None: + self.temp_token_embedding.weight.data[token_ids] = initializer + else: + self.temp_token_embedding.weight.data[token_ids].zero_() + + def load_embed(self, input_ids: list[int], filename: Path): + with safe_open(filename, framework="pt", device="cpu") as file: + self.add_embed(input_ids, file.get_tensor("embed")) + + def save_embed(self, input_ids: list[int], filename: Path): + save_file({"embed": self.get_embed(input_ids)}, filename) + + def make_permanent(self): + self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids] + self.temp_token_ids = torch.tensor([]) + + def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): + if isinstance(input_ids, list): + input_ids = torch.tensor(input_ids) + + mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device)) + + embeds = self.token_embedding(input_ids) + embeds[mask] = self.temp_token_embedding(input_ids)[mask] + + return embeds + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.get_embed(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings: + text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) + text_encoder.text_model.embeddings = text_embeddings + return text_embeddings diff --git a/models/clip/prompt.py b/models/clip/prompt.py index da33ecf..9da3955 100644 --- a/models/clip/prompt.py +++ b/models/clip/prompt.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Union import torch @@ -10,13 +10,13 @@ class PromptProcessor(): self.tokenizer = tokenizer self.text_encoder = text_encoder - def get_input_ids(self, prompt: Union[str, List[str]]): + def get_input_ids(self, prompt: Union[str, list[str]]): return self.tokenizer( prompt, padding="do_not_pad", ).input_ids - def unify_input_ids(self, input_ids: List[int]): + def unify_input_ids(self, input_ids: list[int]): return self.tokenizer.pad( {"input_ids": input_ids}, padding=True, diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py new file mode 100644 index 0000000..78871db --- /dev/null +++ b/models/clip/tokenizer.py @@ -0,0 +1,64 @@ +import copy +from typing import NamedTuple, Union + +import numpy as np + +from transformers import CLIPTokenizer + + +class MultiCLIPTokenizerItem(NamedTuple): + token: str + placeholder_id: int + multi_ids: list[int] + + +class MultiCLIPTokenizer(CLIPTokenizer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.token_map: dict[int, list[int]] = {} + + def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem: + if isinstance(new_tokens, list): + if isinstance(num_vectors, int): + num_vectors = [num_vectors] * len(new_tokens) + + if len(num_vectors) != len(new_tokens): + raise ValueError("Expected new_tokens and num_vectors to have the same len") + + return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)] + + if isinstance(num_vectors, list): + raise ValueError("Expected num_vectors to be int for single token") + + super().add_tokens(new_tokens) + + if num_vectors == 1: + multi_token = [new_tokens] + else: + multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)] + super().add_tokens(multi_token) + + meta_id = super().convert_tokens_to_ids(new_tokens) + multi_ids = super().convert_tokens_to_ids(multi_token) + + self.token_map[meta_id] = multi_ids + + return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids) + + def encode(self, *args, vector_shuffle=True, **kwargs): + ids = super().encode(*args, **kwargs) + new_ids = [] + + for id in ids: + if id in self.token_map: + tokens = self.token_map[id] + + if vector_shuffle: + tokens = copy.copy(tokens) + np.random.shuffle(tokens) + + new_ids = new_ids + self.token_map[id] + else: + new_ids.append(id) + + return new_ids diff --git a/train_ti.py b/train_ti.py index 088c1a6..69d15ea 100644 --- a/train_ti.py +++ b/train_ti.py @@ -16,17 +16,18 @@ from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup import matplotlib.pyplot as plt from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel from slugify import slugify -from common import load_text_embeddings, load_text_embedding, load_config +from common import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule, CSVDataItem from training.optimization import get_one_cycle_schedule from training.lr import LRFinder -from training.ti import patch_trainable_embeddings from training.util import AverageMeter, CheckpointerBase, save_args +from models.clip.embeddings import patch_managed_embeddings from models.clip.prompt import PromptProcessor +from models.clip.tokenizer import MultiCLIPTokenizer logger = get_logger(__name__) @@ -80,6 +81,12 @@ def parse_args(): nargs='*', help="A token to use as initializer word." ) + parser.add_argument( + "--num_vectors", + type=int, + nargs='*', + help="Number of vectors per embedding." + ) parser.add_argument( "--num_class_images", type=int, @@ -360,8 +367,17 @@ def parse_args(): if len(args.placeholder_token) == 0: args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] + if args.num_vectors is None: + args.num_vectors = 1 + + if isinstance(args.num_vectors, int): + args.num_vectors = [args.num_vectors] * len(args.initializer_token) + if len(args.placeholder_token) != len(args.initializer_token): - raise ValueError("You must specify --placeholder_token") + raise ValueError("--placeholder_token and --initializer_token must have the same number of items") + + if len(args.placeholder_token) != len(args.num_vectors): + raise ValueError("--placeholder_token and --num_vectors must have the same number of items") if isinstance(args.collection, str): args.collection = [args.collection] @@ -386,8 +402,7 @@ class Checkpointer(CheckpointerBase): tokenizer, text_encoder, scheduler, - placeholder_token, - placeholder_token_id, + new_tokens, output_dir: Path, sample_image_size, sample_batches, @@ -397,8 +412,6 @@ class Checkpointer(CheckpointerBase): super().__init__( datamodule=datamodule, output_dir=output_dir, - placeholder_token=placeholder_token, - placeholder_token_id=placeholder_token_id, sample_image_size=sample_image_size, seed=seed or torch.random.seed(), sample_batches=sample_batches, @@ -412,6 +425,7 @@ class Checkpointer(CheckpointerBase): self.tokenizer = tokenizer self.text_encoder = text_encoder self.scheduler = scheduler + self.new_tokens = new_tokens @torch.no_grad() def checkpoint(self, step, postfix): @@ -422,13 +436,11 @@ class Checkpointer(CheckpointerBase): text_encoder = self.accelerator.unwrap_model(self.text_encoder) - for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): - # Save a checkpoint - learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id] - learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} - - filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) - torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) + for new_token in self.new_tokens: + text_encoder.text_model.embeddings.save_embed( + new_token.multi_ids, + f"{slugify(new_token.token)}_{step}_{postfix}.bin" + ) del text_encoder del learned_embeds @@ -487,9 +499,9 @@ def main(): # Load the tokenizer and add the placeholder token as a additional special token if args.tokenizer_name: - tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') + tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') @@ -507,45 +519,33 @@ def main(): unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() + embeddings = patch_managed_embeddings(text_encoder) + if args.embeddings_dir is not None: embeddings_dir = Path(args.embeddings_dir) if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") - added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) + + added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}") # Convert the initializer_token, placeholder_token to ids - initializer_token_ids = torch.stack([ - torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) + initializer_token_ids = [ + tokenizer.encode(token, add_special_tokens=False) for token in args.initializer_token - ]) - - num_added_tokens = tokenizer.add_tokens(args.placeholder_token) - print(f"Added {num_added_tokens} new tokens.") - - placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + ] - # Resize the token embeddings as we are adding new special tokens to the tokenizer - text_encoder.resize_token_embeddings(len(tokenizer)) + new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) - # Initialise the newly added placeholder token with the embeddings of the initializer token - token_embeds = text_encoder.get_input_embeddings().weight.data + for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): + embeddings.add_embed(new_token.placeholder_id) + embeddings.add_embed(new_token.multi_ids, init_ids) - if args.resume_from is not None: - resumepath = Path(args.resume_from).joinpath("checkpoints") - - for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): - load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin")) - - initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) - for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): - token_embeds[token_id] = embeddings + print(f"Added {len(new_tokens)} new tokens.") vae.requires_grad_(False) unet.requires_grad_(False) - patch_trainable_embeddings(text_encoder, placeholder_token_id) - text_encoder.text_model.encoder.requires_grad_(False) text_encoder.text_model.final_layer_norm.requires_grad_(False) text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) @@ -575,7 +575,7 @@ def main(): # Initialize the optimizer optimizer = optimizer_class( - text_encoder.text_model.embeddings.trainable_embedding.parameters(), # only optimize the embeddings + text_encoder.text_model.embeddings.temp_token_embedding.parameters(), # only optimize the embeddings lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -816,6 +816,7 @@ def main(): config = vars(args).copy() config["initializer_token"] = " ".join(config["initializer_token"]) config["placeholder_token"] = " ".join(config["placeholder_token"]) + config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]]) if config["collection"] is not None: config["collection"] = " ".join(config["collection"]) if config["exclude_collections"] is not None: @@ -852,8 +853,7 @@ def main(): tokenizer=tokenizer, text_encoder=text_encoder, scheduler=checkpoint_scheduler, - placeholder_token=args.placeholder_token, - placeholder_token_id=placeholder_token_id, + new_tokens=new_tokens, output_dir=basepath, sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, diff --git a/training/util.py b/training/util.py index d0f7fcd..43a55e1 100644 --- a/training/util.py +++ b/training/util.py @@ -1,5 +1,6 @@ from pathlib import Path import json +from typing import Iterable import torch from PIL import Image @@ -39,8 +40,6 @@ class CheckpointerBase: self, datamodule, output_dir: Path, - placeholder_token, - placeholder_token_id, sample_image_size, sample_batches, sample_batch_size, @@ -48,8 +47,6 @@ class CheckpointerBase: ): self.datamodule = datamodule self.output_dir = output_dir - self.placeholder_token = placeholder_token - self.placeholder_token_id = placeholder_token_id self.sample_image_size = sample_image_size self.seed = seed or torch.random.seed() self.sample_batches = sample_batches @@ -117,3 +114,58 @@ class CheckpointerBase: del image_grid del generator + + +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + self.decay = decay + self.optimization_step = 0 + + @torch.no_grad() + def step(self, parameters): + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + value = (1 + self.optimization_step) / (10 + self.optimization_step) + one_minus_decay = 1 - min(self.decay, value) + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + torch.cuda.empty_cache() + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] -- cgit v1.2.3-70-g09d2