 common.py                 |  38
 environment.yaml          |   1
 infer.py                  |  14
 models/clip/embeddings.py | 109
 models/clip/prompt.py     |   6
 models/clip/tokenizer.py  |  64
 train_ti.py               |  88
 training/util.py          |  60
 8 files changed, 299 insertions(+), 81 deletions(-)
diff --git a/common.py b/common.py
--- a/common.py
+++ b/common.py
@@ -1,9 +1,10 @@
 from pathlib import Path
 import json
 
-import torch
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 
-from transformers import CLIPTextModel, CLIPTokenizer
+from safetensors import safe_open
 
 
 def load_config(filename):
@@ -18,33 +19,20 @@ def load_config(filename):
     return args
 
 
-def load_text_embedding(embeddings, token_id, file):
-    data = torch.load(file, map_location="cpu")
-
-    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
-
-    emb = next(iter(data.values()))
-    if len(emb.shape) == 1:
-        emb = emb.unsqueeze(0)
-
-    embeddings[token_id] = emb
-
-
-def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
         return []
 
-    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
-
-    tokens = [file.stem for file in files]
-    added = tokenizer.add_tokens(tokens)
-    token_ids = tokenizer.convert_tokens_to_ids(tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    tokens = [filename.stem for filename in filenames]
 
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for filename in embeddings_dir.iterdir():
+        if filename.is_file():
+            with safe_open(filename, framework="pt", device="cpu") as file:
+                embed = file.get_tensor("embed")
 
-    for (token_id, file) in zip(token_ids, files):
-        load_text_embedding(token_embeds, token_id, file)
+                added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
+                embeddings.add_embed(added.placeholder_id)
+                embeddings.add_embed(added.multi_ids, embed)
 
     return tokens
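
Note: the new load_embeddings_from_dir() expects one safetensors file per token; the file stem becomes the placeholder token, and the first dimension of the stored "embed" tensor becomes its vector count. A minimal sketch of producing a compatible file (the token name "gecko", the 2x768 shape and the SD 1.x embedding width are illustrative assumptions, not part of this commit):

    # Sketch: write an embedding file that load_embeddings_from_dir() can pick up.
    import torch
    from safetensors.torch import save_file

    embed = torch.zeros(2, 768)  # (num_vectors, embedding_dim); 768 assumes an SD 1.x text encoder
    save_file({"embed": embed}, "gecko.bin")
    # Drop the file into the directory passed as --embeddings_dir; the stem "gecko"
    # becomes the placeholder token and embed.shape[0] its number of vectors.
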
diff --git a/environment.yaml b/environment.yaml
index c006379..7f0e903 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,6 +18,7 @@ dependencies:
     - accelerate==0.15.0
     - bitsandbytes==0.35.4
     - python-slugify>=6.1.2
+    - safetensors==0.2.7
     - setuptools==65.6.3
     - test-tube>=0.7.5
     - transformers==4.25.1
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -8,6 +8,7 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
+from slugify import slugify
 from diffusers import (
     AutoencoderKL,
     UNet2DConditionModel,
@@ -20,11 +21,12 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler
 )
-from transformers import CLIPTextModel, CLIPTokenizer
-from slugify import slugify
+from transformers import CLIPTextModel
 
+from models.clip.embeddings import patch_managed_embeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -183,13 +185,15 @@ def save_args(basepath, args, extra={}):
 def create_pipeline(model, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
+    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
     text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir))
+    embeddings = patch_managed_embeddings(text_encoder)
+    added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir))
+
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     pipeline = VlpnStableDiffusion(
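
The infer.py change only swaps in MultiCLIPTokenizer and patches the text encoder before the embeddings directory is read, so multi-vector tokens work at generation time without further changes. Roughly, the call stays as before (the model name and directory below are placeholder values, not defaults of the script):

    # Placeholder values; create_pipeline() is the function defined above in infer.py.
    pipeline = create_pipeline("runwayml/stable-diffusion-v1-5", "embeddings", torch.float16)
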
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
new file mode 100644
index 0000000..7d63ffb
--- /dev/null
+++ b/models/clip/embeddings.py
@@ -0,0 +1,109 @@
+from typing import Union, Optional
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+from transformers import CLIPTextModel
+from transformers.models.clip import CLIPTextConfig
+from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
+
+
+def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
+    old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
+
+    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
+    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
+    new_embedding.weight.data.zero_()
+    new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
+
+    return new_embedding
+
+
+class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+        super().__init__(config)
+
+        self.token_embedding = embeddings.token_embedding
+        self.position_embedding = embeddings.position_embedding
+
+        self.temp_token_embedding = nn.Embedding(
+            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+        self.temp_token_embedding.weight.data.zero_()
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+
+    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+
+        if initializer is not None:
+            if isinstance(initializer, int):
+                initializer = [initializer]
+
+            if isinstance(initializer, list):
+                initializer = (initializer * len(token_ids))[:len(token_ids)]
+
+                with torch.no_grad():
+                    initializer = self.get_embed(initializer)
+
+        self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
+        self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
+
+        token_ids = torch.tensor(token_ids)
+
+        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+
+        if initializer is not None:
+            self.temp_token_embedding.weight.data[token_ids] = initializer
+        else:
+            self.temp_token_embedding.weight.data[token_ids].zero_()
+
+    def load_embed(self, input_ids: list[int], filename: Path):
+        with safe_open(filename, framework="pt", device="cpu") as file:
+            self.add_embed(input_ids, file.get_tensor("embed"))
+
+    def save_embed(self, input_ids: list[int], filename: Path):
+        save_file({"embed": self.get_embed(input_ids)}, filename)
+
+    def make_permanent(self):
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.temp_token_ids = torch.tensor([], dtype=torch.long)
+
+    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
+        if isinstance(input_ids, list):
+            input_ids = torch.tensor(input_ids)
+
+        mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
+
+        embeds = self.token_embedding(input_ids)
+        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
+
+        return embeds
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embed(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+    text_encoder.text_model.embeddings = text_embeddings
+    return text_embeddings
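
ManagedCLIPTextEmbeddings keeps the pretrained token_embedding intact and routes newly added ids through a parallel temp_token_embedding: get_embed() overlays the temporary rows, save_embed()/load_embed() round-trip them through a safetensors file with a single "embed" tensor, and make_permanent() folds them back into the main table. A small sketch of the intended flow (the model name is the stock SD 1.x text encoder; the ids are illustrative and would normally come from MultiCLIPTokenizer.add_multi_tokens):

    from transformers import CLIPTextModel
    from models.clip.embeddings import patch_managed_embeddings

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeddings = patch_managed_embeddings(text_encoder)

    new_ids = [49408, 49409]   # ids of two freshly appended tokens (illustrative values)
    init_ids = [1125]          # id of an existing word used to seed the new rows (illustrative)
    embeddings.add_embed(new_ids, init_ids)   # the rows live in temp_token_embedding for training

    # Once training is done, fold the trained rows into the regular embedding table:
    embeddings.make_permanent()
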
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index da33ecf..9da3955 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Union
 
 import torch
 
@@ -10,13 +10,13 @@ class PromptProcessor():
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
 
-    def get_input_ids(self, prompt: Union[str, List[str]]):
+    def get_input_ids(self, prompt: Union[str, list[str]]):
         return self.tokenizer(
             prompt,
             padding="do_not_pad",
         ).input_ids
 
-    def unify_input_ids(self, input_ids: List[int]):
+    def unify_input_ids(self, input_ids: list[int]):
         return self.tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
new file mode 100644
index 0000000..78871db
--- /dev/null
+++ b/models/clip/tokenizer.py
@@ -0,0 +1,64 @@
+import copy
+from typing import NamedTuple, Union
+
+import numpy as np
+
+from transformers import CLIPTokenizer
+
+
+class MultiCLIPTokenizerItem(NamedTuple):
+    token: str
+    placeholder_id: int
+    multi_ids: list[int]
+
+
+class MultiCLIPTokenizer(CLIPTokenizer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.token_map: dict[int, list[int]] = {}
+
+    def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
+        if isinstance(new_tokens, list):
+            if isinstance(num_vectors, int):
+                num_vectors = [num_vectors] * len(new_tokens)
+
+            if len(num_vectors) != len(new_tokens):
+                raise ValueError("Expected new_tokens and num_vectors to have the same len")
+
+            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]
+
+        if isinstance(num_vectors, list):
+            raise ValueError("Expected num_vectors to be int for single token")
+
+        super().add_tokens(new_tokens)
+
+        if num_vectors == 1:
+            multi_token = [new_tokens]
+        else:
+            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
+            super().add_tokens(multi_token)
+
+        meta_id = super().convert_tokens_to_ids(new_tokens)
+        multi_ids = super().convert_tokens_to_ids(multi_token)
+
+        self.token_map[meta_id] = multi_ids
+
+        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
+
+    def encode(self, *args, vector_shuffle=True, **kwargs):
+        ids = super().encode(*args, **kwargs)
+        new_ids = []
+
+        for id in ids:
+            if id in self.token_map:
+                tokens = self.token_map[id]
+
+                if vector_shuffle:
+                    tokens = copy.copy(tokens)
+                    np.random.shuffle(tokens)
+
+                new_ids = new_ids + tokens
+            else:
+                new_ids.append(id)
+
+        return new_ids
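
MultiCLIPTokenizer keeps a map from each placeholder token to the ids of its per-vector sub-tokens, and encode() expands the placeholder in place. A usage sketch (the model path, token name and vector count are illustrative, not taken from the diff):

    from models.clip.tokenizer import MultiCLIPTokenizer

    tokenizer = MultiCLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="tokenizer")

    item = tokenizer.add_multi_tokens("<gecko>", num_vectors=3)
    # item.placeholder_id is the id of "<gecko>"; item.multi_ids are the ids of
    # "<gecko>_0".."<gecko>_2", which carry the learned vectors.

    ids = tokenizer.encode("a photo of <gecko>", add_special_tokens=False)
    # "<gecko>" is replaced by its three sub-token ids, shuffled by default
    # (pass vector_shuffle=False to keep them in order).
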
diff --git a/train_ti.py b/train_ti.py
index 088c1a6..69d15ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -16,17 +16,18 @@ from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler,
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_text_embeddings, load_text_embedding, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -81,6 +82,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
@@ -360,8 +367,17 @@ def parse_args():
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
 
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("You must specify --placeholder_token")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -386,8 +402,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        placeholder_token,
-        placeholder_token_id,
+        new_tokens,
         output_dir: Path,
         sample_image_size,
         sample_batches,
@@ -397,8 +412,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            placeholder_token=placeholder_token,
-            placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
             seed=seed or torch.random.seed(),
             sample_batches=sample_batches,
@@ -412,6 +425,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.new_tokens = new_tokens
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -422,13 +436,11 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        for new_token in self.new_tokens:
+            text_encoder.text_model.embeddings.save_embed(
+                new_token.multi_ids,
+                f"{slugify(new_token.token)}_{step}_{postfix}.bin"
+            )
 
         del text_encoder
         del learned_embeds
@@ -487,9 +499,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -507,45 +519,33 @@ def main():
     unet.enable_gradient_checkpointing()
     text_encoder.gradient_checkpointing_enable()
 
+    embeddings = patch_managed_embeddings(text_encoder)
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
         for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+    ]
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+        embeddings.add_embed(new_token.placeholder_id)
+        embeddings.add_embed(new_token.multi_ids, init_ids)
 
-    if args.resume_from is not None:
-        resumepath = Path(args.resume_from).joinpath("checkpoints")
-
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
-
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+    print(f"Added {len(new_tokens)} new tokens.")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
-    patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
@@ -575,7 +575,7 @@ def main():
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.trainable_embedding.parameters(), # only optimize the embeddings
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(), # only optimize the embeddings
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -816,6 +816,7 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
    config["placeholder_token"] = " ".join(config["placeholder_token"])
+    config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
     if config["collection"] is not None:
         config["collection"] = " ".join(config["collection"])
     if config["exclude_collections"] is not None:
@@ -852,8 +853,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
+        new_tokens=new_tokens,
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
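
Taken together, the train_ti.py changes mean each --placeholder_token now owns one meta token plus --num_vectors trainable sub-token rows, and only temp_token_embedding is handed to the optimizer. A condensed sketch of the setup path in main() (the token "<gecko>", the word "lizard" and the count 4 are illustrative, not script defaults):

    # tokenizer is a MultiCLIPTokenizer, embeddings the patched text-encoder embeddings.
    initializer_token_ids = [tokenizer.encode("lizard", add_special_tokens=False)]
    new_tokens = tokenizer.add_multi_tokens(["<gecko>"], [4])

    for new_token, init_ids in zip(new_tokens, initializer_token_ids):
        embeddings.add_embed(new_token.placeholder_id)       # meta id, left untrained
        embeddings.add_embed(new_token.multi_ids, init_ids)  # 4 rows, seeded from "lizard"
    # add_embed() repeats or truncates init_ids to match the vector count, so a
    # single initializer word can seed several vectors.
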
diff --git a/training/util.py b/training/util.py
index d0f7fcd..43a55e1 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+from typing import Iterable
 
 import torch
 from PIL import Image
@@ -39,8 +40,6 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -48,8 +47,6 @@ class CheckpointerBase:
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
@@ -117,3 +114,58 @@ class CheckpointerBase:
         del image_grid
 
         del generator
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of models weights
+    """
+
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        self.decay = decay
+        self.optimization_step = 0
+
+    @torch.no_grad()
+    def step(self, parameters):
+        parameters = list(parameters)
+
+        self.optimization_step += 1
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
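
EMAModel keeps detached shadow copies of the tracked parameters and blends them toward the live values on every step; this diff only adds the class, it is not wired into the training loop here. A sketch of how it could be used (dataloader, train_step and the tracked module are assumptions for illustration):

    from training.util import EMAModel

    params = lambda: text_encoder.text_model.embeddings.temp_token_embedding.parameters()
    ema = EMAModel(params(), decay=0.9999)

    for batch in dataloader:
        loss = train_step(batch)   # hypothetical optimization step
        ema.step(params())

    # For sampling or checkpointing, overwrite the live weights with the averaged ones:
    ema.copy_to(params())
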
