-rw-r--r--  common.py                  |  38
-rw-r--r--  environment.yaml           |   1
-rw-r--r--  infer.py                   |  14
-rw-r--r--  models/clip/embeddings.py  | 109
-rw-r--r--  models/clip/prompt.py      |   6
-rw-r--r--  models/clip/tokenizer.py   |  64
-rw-r--r--  train_ti.py                |  88
-rw-r--r--  training/util.py           |  60
8 files changed, 299 insertions(+), 81 deletions(-)
diff --git a/common.py b/common.py
index f369475..e8d3ac1 100644
--- a/common.py
+++ b/common.py
@@ -1,9 +1,10 @@
 from pathlib import Path
 import json
 
-import torch
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 
-from transformers import CLIPTextModel, CLIPTokenizer
+from safetensors import safe_open
 
 
 def load_config(filename):
@@ -18,33 +19,20 @@ def load_config(filename):
     return args
 
 
-def load_text_embedding(embeddings, token_id, file):
-    data = torch.load(file, map_location="cpu")
-
-    assert len(data.keys()) == 1, 'embedding data has multiple terms in it'
-
-    emb = next(iter(data.values()))
-    if len(emb.shape) == 1:
-        emb = emb.unsqueeze(0)
-
-    embeddings[token_id] = emb
-
-
-def load_text_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, embeddings_dir: Path):
+def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
         return []
 
-    files = [file for file in embeddings_dir.iterdir() if file.is_file()]
-
-    tokens = [file.stem for file in files]
-    added = tokenizer.add_tokens(tokens)
-    token_ids = tokenizer.convert_tokens_to_ids(tokens)
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    tokens = [filename.stem for filename in filenames]
 
-    for (token_id, file) in zip(token_ids, files):
-        load_text_embedding(token_embeds, token_id, file)
+    for filename in embeddings_dir.iterdir():
+        if filename.is_file():
+            with safe_open(filename, framework="pt", device="cpu") as file:
+                embed = file.get_tensor("embed")
+
+                added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
+                embeddings.add_embed(added.placeholder_id)
+                embeddings.add_embed(added.multi_ids, embed)
 
     return tokens
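
For reference, a sketch of producing an embedding file in the layout load_embeddings_from_dir now expects: a safetensors file whose stem is the token name and whose "embed" tensor has one row per vector. The directory, token name and embedding size below are illustrative, not values from this commit.

```python
# Hedged sketch: writing an embedding file that load_embeddings_from_dir can pick up.
import torch
from safetensors.torch import save_file

embeddings_dir = "embeddings"   # directory later passed as embeddings_dir (placeholder)
token = "<my-token>"            # the file stem becomes the token name
embed = torch.randn(3, 768)     # 3 vectors for this token, 768-dim CLIP embeddings (assumed size)

save_file({"embed": embed}, f"{embeddings_dir}/{token}.bin")
```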
diff --git a/environment.yaml b/environment.yaml
index c006379..7f0e903 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,6 +18,7 @@ dependencies:
 - accelerate==0.15.0
 - bitsandbytes==0.35.4
 - python-slugify>=6.1.2
+- safetensors==0.2.7
 - setuptools==65.6.3
 - test-tube>=0.7.5
 - transformers==4.25.1
diff --git a/infer.py b/infer.py
index ae0b4da..4bcaff5 100644
--- a/infer.py
+++ b/infer.py
@@ -8,6 +8,7 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
+from slugify import slugify
 from diffusers import (
     AutoencoderKL,
     UNet2DConditionModel,
@@ -20,11 +21,12 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler
 )
-from transformers import CLIPTextModel, CLIPTokenizer
-from slugify import slugify
+from transformers import CLIPTextModel
 
+from models.clip.embeddings import patch_managed_embeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from common import load_text_embeddings, load_config
+from common import load_config, load_embeddings_from_dir
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -183,13 +185,15 @@ def save_args(basepath, args, extra={}):
 def create_pipeline(model, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
+    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
     text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir))
+    embeddings = patch_managed_embeddings(text_encoder)
+    added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir))
+
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     pipeline = VlpnStableDiffusion(
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
new file mode 100644
index 0000000..7d63ffb
--- /dev/null
+++ b/models/clip/embeddings.py
@@ -0,0 +1,109 @@
+from typing import Union, Optional
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+
+from safetensors import safe_open
+from safetensors.torch import save_file
+
+from transformers import CLIPTextModel
+from transformers.models.clip import CLIPTextConfig
+from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
+
+
+def expand_embedding(old_embedding: nn.Embedding, n: int) -> nn.Embedding:
+    old_num_embeddings, old_embedding_dim = old_embedding.weight.size()
+
+    new_embedding = nn.Embedding(old_num_embeddings + n, old_embedding_dim)
+    new_embedding.to(old_embedding.weight.device, dtype=old_embedding.weight.dtype)
+    new_embedding.weight.data.zero_()
+    new_embedding.weight.data[:old_num_embeddings] = old_embedding.weight.data
+
+    return new_embedding
+
+
+class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+        super().__init__(config)
+
+        self.token_embedding = embeddings.token_embedding
+        self.position_embedding = embeddings.position_embedding
+
+        self.temp_token_embedding = nn.Embedding(
+            self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+        self.temp_token_embedding.weight.data.zero_()
+        self.temp_token_ids = torch.tensor([])
+
+    def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+
+        if initializer is not None:
+            if isinstance(initializer, int):
+                initializer = [initializer]
+
+            if isinstance(initializer, list):
+                initializer = (initializer * len(token_ids))[:len(token_ids)]
+
+                with torch.no_grad():
+                    initializer = self.get_embed(initializer)
+
+        self.temp_token_embedding = expand_embedding(self.temp_token_embedding, len(token_ids))
+        self.token_embedding = expand_embedding(self.token_embedding, len(token_ids))
+
+        token_ids = torch.tensor(token_ids)
+
+        self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+
+        if initializer is not None:
+            self.temp_token_embedding.weight.data[token_ids] = initializer
+        else:
+            self.temp_token_embedding.weight.data[token_ids].zero_()
+
+    def load_embed(self, input_ids: list[int], filename: Path):
+        with safe_open(filename, framework="pt", device="cpu") as file:
+            self.add_embed(input_ids, file.get_tensor("embed"))
+
+    def save_embed(self, input_ids: list[int], filename: Path):
+        save_file({"embed": self.get_embed(input_ids)}, filename)
+
+    def make_permanent(self):
+        self.token_embedding.weight.data[self.temp_token_ids] = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.temp_token_ids = torch.tensor([])
+
+    def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
+        if isinstance(input_ids, list):
+            input_ids = torch.tensor(input_ids)
+
+        mask = torch.isin(input_ids, torch.tensor(self.temp_token_ids, device=input_ids.device))
+
+        embeds = self.token_embedding(input_ids)
+        embeds[mask] = self.temp_token_embedding(input_ids)[mask]
+
+        return embeds
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embed(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+    text_encoder.text_model.embeddings = text_embeddings
+    return text_embeddings
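
A minimal usage sketch of the new embeddings module, assuming a plain CLIP text encoder checkpoint; the checkpoint name, token ids and initializer id below are illustrative, not part of this commit.

```python
# Hedged sketch: how the managed embeddings are meant to be used.
from pathlib import Path

from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")  # example checkpoint
embeddings = patch_managed_embeddings(text_encoder)  # swaps in ManagedCLIPTextEmbeddings

# Suppose the tokenizer just grew by three ids (49408..49410): register them and
# initialize each from existing token id 320. They live in temp_token_embedding
# (the only trainable table) until make_permanent() folds them into token_embedding.
embeddings.add_embed([49408, 49409, 49410], initializer=320)

# Persist them the same way the Checkpointer does: a safetensors file with key "embed".
embeddings.save_embed([49408, 49409, 49410], Path("my-token.bin"))
```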
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index da33ecf..9da3955 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Union
 
 import torch
 
@@ -10,13 +10,13 @@ class PromptProcessor():
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
 
-    def get_input_ids(self, prompt: Union[str, List[str]]):
+    def get_input_ids(self, prompt: Union[str, list[str]]):
         return self.tokenizer(
             prompt,
             padding="do_not_pad",
         ).input_ids
 
-    def unify_input_ids(self, input_ids: List[int]):
+    def unify_input_ids(self, input_ids: list[int]):
         return self.tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
new file mode 100644
index 0000000..78871db
--- /dev/null
+++ b/models/clip/tokenizer.py
@@ -0,0 +1,64 @@
+import copy
+from typing import NamedTuple, Union
+
+import numpy as np
+
+from transformers import CLIPTokenizer
+
+
+class MultiCLIPTokenizerItem(NamedTuple):
+    token: str
+    placeholder_id: int
+    multi_ids: list[int]
+
+
+class MultiCLIPTokenizer(CLIPTokenizer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.token_map: dict[int, list[int]] = {}
+
+    def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
+        if isinstance(new_tokens, list):
+            if isinstance(num_vectors, int):
+                num_vectors = [num_vectors] * len(new_tokens)
+
+            if len(num_vectors) != len(new_tokens):
+                raise ValueError("Expected new_tokens and num_vectors to have the same len")
+
+            return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)]
+
+        if isinstance(num_vectors, list):
+            raise ValueError("Expected num_vectors to be int for single token")
+
+        super().add_tokens(new_tokens)
+
+        if num_vectors == 1:
+            multi_token = [new_tokens]
+        else:
+            multi_token = [f"{new_tokens}_{i}" for i in range(num_vectors)]
+            super().add_tokens(multi_token)
+
+        meta_id = super().convert_tokens_to_ids(new_tokens)
+        multi_ids = super().convert_tokens_to_ids(multi_token)
+
+        self.token_map[meta_id] = multi_ids
+
+        return MultiCLIPTokenizerItem(new_tokens, meta_id, multi_ids)
+
+    def encode(self, *args, vector_shuffle=True, **kwargs):
+        ids = super().encode(*args, **kwargs)
+        new_ids = []
+
+        for id in ids:
+            if id in self.token_map:
+                tokens = self.token_map[id]
+
+                if vector_shuffle:
+                    tokens = copy.copy(tokens)
+                    np.random.shuffle(tokens)
+
+                new_ids = new_ids + self.token_map[id]
+            else:
+                new_ids.append(id)
+
+        return new_ids
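
The new tokenizer in a nutshell: one placeholder token is registered together with N sub-tokens, and encode() expands the placeholder into those sub-token ids. A short usage sketch; the checkpoint name and token are examples, not part of this commit.

```python
# Hedged sketch: multi-vector placeholder tokens.
from models.clip.tokenizer import MultiCLIPTokenizer

tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # example checkpoint

item = tokenizer.add_multi_tokens("<cat-toy>", num_vectors=3)
# item.placeholder_id is the id of "<cat-toy>" itself; item.multi_ids are the ids of
# "<cat-toy>_0" .. "<cat-toy>_2" that stand in for it at encode time.

ids = tokenizer.encode("a photo of <cat-toy>", add_special_tokens=False)
# The single placeholder id is replaced by the three multi_ids; the vector_shuffle
# flag is meant to randomize their order per call.
```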
diff --git a/train_ti.py b/train_ti.py
index 088c1a6..69d15ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -16,17 +16,18 @@ from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler,
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_text_embeddings, load_text_embedding, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -81,6 +82,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
@@ -360,8 +367,17 @@ def parse_args():
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
 
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("You must specify --placeholder_token")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -386,8 +402,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        placeholder_token,
-        placeholder_token_id,
+        new_tokens,
         output_dir: Path,
         sample_image_size,
         sample_batches,
@@ -397,8 +412,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            placeholder_token=placeholder_token,
-            placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
             seed=seed or torch.random.seed(),
             sample_batches=sample_batches,
@@ -412,6 +425,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.new_tokens = new_tokens
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -422,13 +436,11 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        for new_token in self.new_tokens:
+            text_encoder.text_model.embeddings.save_embed(
+                new_token.multi_ids,
+                f"{slugify(new_token.token)}_{step}_{postfix}.bin"
+            )
 
         del text_encoder
         del learned_embeds
@@ -487,9 +499,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -507,45 +519,33 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    embeddings = patch_managed_embeddings(text_encoder)
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
         for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+    ]
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+        embeddings.add_embed(new_token.placeholder_id)
+        embeddings.add_embed(new_token.multi_ids, init_ids)
 
-    if args.resume_from is not None:
-        resumepath = Path(args.resume_from).joinpath("checkpoints")
-
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-        token_embeds[token_id] = embeddings
+    print(f"Added {len(new_tokens)} new tokens.")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
-    patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
@@ -575,7 +575,7 @@ def main():
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.trainable_embedding.parameters(),  # only optimize the embeddings
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),  # only optimize the embeddings
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -816,6 +816,7 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
+    config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
     if config["collection"] is not None:
         config["collection"] = " ".join(config["collection"])
     if config["exclude_collections"] is not None:
@@ -852,8 +853,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
+        new_tokens=new_tokens,
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
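
Condensed from the hunks above, the new textual-inversion setup reduces to roughly this flow; the model path, token and hyperparameters are placeholders, not values taken from this commit.

```python
# Condensed sketch of the new train_ti.py wiring (paths/tokens are placeholders).
import torch
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer

model = "path/to/stable-diffusion"  # placeholder
tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder")

embeddings = patch_managed_embeddings(text_encoder)

# One placeholder token backed by two vectors, initialized from "cat".
init_ids = tokenizer.encode("cat", add_special_tokens=False)
new_token = tokenizer.add_multi_tokens("<my-cat>", 2)
embeddings.add_embed(new_token.placeholder_id)
embeddings.add_embed(new_token.multi_ids, init_ids)

# Only the temporary embedding table is trained.
optimizer = torch.optim.AdamW(
    text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
    lr=1e-4,
)
```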
diff --git a/training/util.py b/training/util.py
index d0f7fcd..43a55e1 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+from typing import Iterable
 
 import torch
 from PIL import Image
@@ -39,8 +40,6 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -48,8 +47,6 @@ class CheckpointerBase:
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
@@ -117,3 +114,58 @@ class CheckpointerBase:
             del image_grid
 
         del generator
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of models weights
+    """
+
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        self.decay = decay
+        self.optimization_step = 0
+
+    @torch.no_grad()
+    def step(self, parameters):
+        parameters = list(parameters)
+
+        self.optimization_step += 1
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
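
The EMAModel added above follows the usual pattern: update the shadow weights after every optimizer step, then copy them back before evaluation or export. A self-contained sketch with a toy module; the module and hyperparameters are illustrative.

```python
# Hedged sketch: keeping an EMA copy of a small module's weights.
import torch
import torch.nn as nn

from training.util import EMAModel

model = nn.Linear(16, 16)
ema = EMAModel(model.parameters(), decay=0.9999)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(100):
    loss = model(torch.randn(4, 16)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(model.parameters())  # update shadow parameters after each optimizer step

ema.copy_to(model.parameters())  # overwrite live weights with the averaged ones, e.g. before sampling
```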