-rw-r--r--  common.py                  |  38
-rw-r--r--  environment.yaml           |   1
-rw-r--r--  infer.py                   |  14
-rw-r--r--  models/clip/embeddings.py  | 109
-rw-r--r--  models/clip/prompt.py      |   6
-rw-r--r--  models/clip/tokenizer.py   |  64
-rw-r--r--  train_ti.py                |  88
-rw-r--r--  training/util.py           |  60
8 files changed, 299 insertions(+), 81 deletions(-)
diff --git a/common.py b/common.py
@@ -1,9 +1,10 @@
 from pathlib import Path
 import json
 
-import torch
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 
-from transformers import CLIPTextModel, CLIPTokenizer
+from safetensors import safe_open
 
 
 def load_config(filename):
@@ -18,33 +19,20 @@ def load_config(filename):
     return args
 
 
-def load_text_embedding(embeddings, token_id, file):
-    data = torch.load(file, map_location="cpu")
-
-    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
-
-    emb = next(iter(data.values()))
-    if len(emb.shape) == 1:
-        emb = emb.unsqueeze(0)
-
-    embeddings[token_id] = emb
-
-
-def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
         return []
 
-    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
-
-    tokens = [file.stem for file in files]
-    added = tokenizer.add_tokens(tokens)
-    token_ids = tokenizer.convert_tokens_to_ids(tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    tokens = [filename.stem for filename in filenames]
 
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for filename in embeddings_dir.iterdir():
+        if filename.is_file():
+            with safe_open(filename, framework="pt", device="cpu") as file:
+                embed = file.get_tensor("embed")
 
-    for (token_id, file) in zip(token_ids, files):
-        load_text_embedding(token_embeds, token_id, file)
+            added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
+            embeddings.add_embed(added.placeholder_id)
+            embeddings.add_embed(added.multi_ids, embed)
 
     return tokens
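The new loader treats every file in the directory as one placeholder token: the file stem becomes the token name and the "embed" tensor supplies one row per vector. A minimal sketch of producing such a file, assuming an arbitrary token name, path, and embedding size (none of these values come from the repository):

```python
# Hypothetical example: write an embedding file that load_embeddings_from_dir can pick up.
import torch
from safetensors.torch import save_file

embed = torch.zeros(3, 768)  # three vectors; 768 is the CLIP ViT-L text hidden size
save_file({"embed": embed}, "embeddings/my-token.bin")  # file stem "my-token" becomes the token
```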
diff --git a/environment.yaml b/environment.yaml
index c006379..7f0e903 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,6 +18,7 @@ dependencies:
     - accelerate==0.15.0
     - bitsandbytes==0.35.4
     - python-slugify>=6.1.2
+    - safetensors==0.2.7
     - setuptools==65.6.3
     - test-tube>=0.7.5
     - transformers==4.25.1
diff --git a/infer.py b/infer.py
@@ -8,6 +8,7 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
+from slugify import slugify
 from diffusers import (
     AutoencoderKL,
     UNet2DConditionModel,
@@ -20,11 +21,12 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler
 )
-from transformers import CLIPTextModel, CLIPTokenizer
-from slugify import slugify
+from transformers import CLIPTextModel
 
+from models.clip.embeddings import patch_managed_embeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -183,13 +185,15 @@ def save_args(basepath, args, extra={}):
 def create_pipeline(model, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
+    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
     text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir))
+    embeddings = patch_managed_embeddings(text_encoder)
+    added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir))
+
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     pipeline = VlpnStableDiffusion(
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
new file mode 100644
index 0000000..7d63ffb
--- /dev/null
+++ b/models/clip/embeddings.py
@@ -0,0 +1,109 @@
+from typing import Union, Optional
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+from transformers import CLIPTextModel
+from transformers.models.clip import CLIPTextConfig
+from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
+
+
+def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
+    old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
+
+    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
+    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
+    new_embedding.weight.data.zero_()
+    new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
+
+    return new_embedding
+
+
+class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+        super().__init__(config)
+
+        self.token_embedding = embeddings.token_embedding
+        self.position_embedding = embeddings.position_embedding
+
+        self.temp_token_embedding = nn.Embedding(
+            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+        self.temp_token_embedding.weight.data.zero_()
+        self.temp_token_ids = torch.tensor([])
+
+    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+
+        if initializer is not None:
+            if isinstance(initializer, int):
+                initializer = [initializer]
+
+            if isinstance(initializer, list):
+                initializer = (initializer * len(token_ids))[:len(token_ids)]
+
+                with torch.no_grad():
+                    initializer = self.get_embed(initializer)
+
+        self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
+        self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
+
+        token_ids = torch.tensor(token_ids)
+
+        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+
+        if initializer is not None:
+            self.temp_token_embedding.weight.data[token_ids] = initializer
+        else:
+            self.temp_token_embedding.weight.data[token_ids].zero_()
+
+    def load_embed(self, input_ids: list[int], filename: Path):
+        with safe_open(filename, framework="pt", device="cpu") as file:
+            self.add_embed(input_ids, file.get_tensor("embed"))
+
+    def save_embed(self, input_ids: list[int], filename: Path):
+        save_file({"embed": self.get_embed(input_ids)}, filename)
+
+    def make_permanent(self):
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.temp_token_ids = torch.tensor([])
+
+    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
+        if isinstance(input_ids, list):
+            input_ids = torch.tensor(input_ids)
+
+        mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
+
+        embeds = self.token_embedding(input_ids)
+        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
+
+        return embeds
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embed(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+    text_encoder.text_model.embeddings = text_embeddings
+    return text_embeddings
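A short usage sketch of the managed embeddings, assuming a stock CLIP text encoder; the token ids, initializer id, and file name below are illustrative and not taken from this commit:

```python
# Illustrative sketch: rows registered via add_embed live in temp_token_embedding and
# override the regular token_embedding during lookup.
import torch
from transformers import CLIPTextModel
from models.clip.embeddings import patch_managed_embeddings

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
embeddings = patch_managed_embeddings(text_encoder)

new_ids = [49408, 49409]                               # ids a tokenizer would assign to two new sub-tokens
embeddings.add_embed(new_ids, initializer=[320, 320])  # start from the embedding of token id 320

with torch.no_grad():
    vectors = embeddings.get_embed(new_ids)            # served from the temporary table
    embeddings.save_embed(new_ids, "example.bin")      # stored as a single "embed" tensor
```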
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index da33ecf..9da3955 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Union
 
 import torch
 
@@ -10,13 +10,13 @@ class PromptProcessor():
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
 
-    def get_input_ids(self, prompt: Union[str, List[str]]):
+    def get_input_ids(self, prompt: Union[str, list[str]]):
         return self.tokenizer(
             prompt,
             padding="do_not_pad",
         ).input_ids
 
-    def unify_input_ids(self, input_ids: List[int]):
+    def unify_input_ids(self, input_ids: list[int]):
         return self.tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
new file mode 100644
index 0000000..78871db
--- /dev/null
+++ b/models/clip/tokenizer.py
@@ -0,0 +1,64 @@
+import copy
+from typing import NamedTuple, Union
+
+import numpy as np
+
+from transformers import CLIPTokenizer
+
+
+class MultiCLIPTokenizerItem(NamedTuple):
+    token: str
+    placeholder_id: int
+    multi_ids: list[int]
+
+
+class MultiCLIPTokenizer(CLIPTokenizer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.token_map: dict[int, list[int]] = {}
+
+    def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
+        if isinstance(new_tokens, list):
+            if isinstance(num_vectors, int):
+                num_vectors = [num_vectors] * len(new_tokens)
+
+            if len(num_vectors) != len(new_tokens):
+                raise ValueError("Expected new_tokens and num_vectors to have the same len")
+
+            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]
+
+        if isinstance(num_vectors, list):
+            raise ValueError("Expected num_vectors to be int for single token")
+
+        super().add_tokens(new_tokens)
+
+        if num_vectors == 1:
+            multi_token = [new_tokens]
+        else:
+            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
+        super().add_tokens(multi_token)
+
+        meta_id = super().convert_tokens_to_ids(new_tokens)
+        multi_ids = super().convert_tokens_to_ids(multi_token)
+
+        self.token_map[meta_id] = multi_ids
+
+        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
+
+    def encode(self, *args, vector_shuffle=True, **kwargs):
+        ids = super().encode(*args, **kwargs)
+        new_ids = []
+
+        for id in ids:
+            if id in self.token_map:
+                tokens = self.token_map[id]
+
+                if vector_shuffle:
+                    tokens = copy.copy(tokens)
+                    np.random.shuffle(tokens)
+
+                new_ids = new_ids + tokens
+            else:
+                new_ids.append(id)
+
+        return new_ids
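A usage sketch for the new tokenizer, assuming the stock CLIP checkpoint named below (any CLIPTokenizer weights should work); the placeholder token is made up, and vector_shuffle is intended to randomize the sub-token order during training:

```python
# Illustrative sketch: one placeholder token backed by three sub-tokens.
from models.clip.tokenizer import MultiCLIPTokenizer

tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

item = tokenizer.add_multi_tokens("<my-token>", 3)
# item.placeholder_id is the id of "<my-token>";
# item.multi_ids are the ids of "<my-token>_0" .. "<my-token>_2".

ids = tokenizer.encode("a photo of <my-token>", add_special_tokens=False)
# The single placeholder id is expanded to the three sub-token ids.
```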
diff --git a/train_ti.py b/train_ti.py
index 088c1a6..69d15ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -16,17 +16,18 @@ from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler,
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_text_embeddings, load_text_embedding, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -81,6 +82,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
@@ -360,8 +367,17 @@
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
 
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("You must specify --placeholder_token")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -386,8 +402,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        placeholder_token,
-        placeholder_token_id,
+        new_tokens,
         output_dir: Path,
         sample_image_size,
         sample_batches,
@@ -397,8 +412,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            placeholder_token=placeholder_token,
-            placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
             seed=seed or torch.random.seed(),
             sample_batches=sample_batches,
@@ -412,6 +425,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.new_tokens = new_tokens
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -422,13 +436,11 @@
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        for new_token in self.new_tokens:
+            text_encoder.text_model.embeddings.save_embed(
+                new_token.multi_ids,
+                f"{slugify(new_token.token)}_{step}_{postfix}.bin"
+            )
 
         del text_encoder
         del learned_embeds
@@ -487,9 +499,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -507,45 +519,33 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    embeddings = patch_managed_embeddings(text_encoder)
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
         for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+    ]
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+        embeddings.add_embed(new_token.placeholder_id)
+        embeddings.add_embed(new_token.multi_ids, init_ids)
 
-    if args.resume_from is not None:
-        resumepath = Path(args.resume_from).joinpath("checkpoints")
-
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-        token_embeds[token_id] = embeddings
+    print(f"Added {len(new_tokens)} new tokens.")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
-    patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
@@ -575,7 +575,7 @@
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.trainable_embedding.parameters(), # only optimize the embeddings
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(), # only optimize the embeddings
        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -816,6 +816,7 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
+    config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
     if config["collection"] is not None:
         config["collection"] = " ".join(config["collection"])
     if config["exclude_collections"] is not None:
@@ -852,8 +853,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
+        new_tokens=new_tokens,
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
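Putting the pieces together, a hedged sketch of the setup main() now performs for a single placeholder token; the checkpoint id, token names, vector count, and learning rate below are placeholders, not the script's defaults:

```python
# Illustrative wiring only: multi-vector token setup plus an optimizer over the
# temporary embedding table, mirroring the flow in train_ti.py.
import torch
from transformers import CLIPTextModel
from models.clip.embeddings import patch_managed_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer

model = "runwayml/stable-diffusion-v1-5"  # example model name
tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder")
embeddings = patch_managed_embeddings(text_encoder)

init_ids = tokenizer.encode("person", add_special_tokens=False)  # initializer word
new_token = tokenizer.add_multi_tokens("<me>", 2)                # e.g. --num_vectors 2
embeddings.add_embed(new_token.placeholder_id)                   # meta token, zero-initialized
embeddings.add_embed(new_token.multi_ids, init_ids)              # sub-tokens start from "person"

# Only the temporary embedding table receives gradients:
optimizer = torch.optim.AdamW(
    text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
    lr=1e-2,
)
```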
diff --git a/training/util.py b/training/util.py
index d0f7fcd..43a55e1 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+from typing import Iterable
 
 import torch
 from PIL import Image
@@ -39,8 +40,6 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -48,8 +47,6 @@ class CheckpointerBase:
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
@@ -117,3 +114,58 @@ class CheckpointerBase:
                 del image_grid
 
         del generator
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of models weights
+    """
+
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        self.decay = decay
+        self.optimization_step = 0
+
+    @torch.no_grad()
+    def step(self, parameters):
+        parameters = list(parameters)
+
+        self.optimization_step += 1
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
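A minimal sketch of how the new EMAModel helper can be driven from a training loop; the model, optimizer, and step count are stand-ins, not values taken from the repository:

```python
# Illustrative only: keep an exponential moving average of some parameters and
# swap the averaged weights in, e.g. before evaluation or saving.
import torch
from training.util import EMAModel

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = EMAModel(model.parameters(), decay=0.9999)

for _ in range(100):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(model.parameters())   # update the shadow parameters after each optimizer step

ema.copy_to(model.parameters())    # load the averaged weights into the live model
```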