 data/csv.py          | 64
 dreambooth.py        | 68
 textual_inversion.py | 58
 3 files changed, 115 insertions(+), 75 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 5144c0a..f9b5e39 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,16 +1,20 @@
 import math
-import pandas as pd
 import torch
+import json
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import NamedTuple, List, Optional
+from typing import Dict, NamedTuple, List, Optional, Union
 
 from models.clip.prompt import PromptProcessor
 
 
+def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+    return {"content": prompt} if isinstance(prompt, str) else prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -60,24 +64,32 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
-    def prepare_subdata(self, data, num_class_images=1):
+    def prepare_subdata(self, template, data, num_class_images=1):
+        image = template["image"] if "image" in template else "{}"
+        prompt = template["prompt"] if "prompt" in template else "{content}"
+        nprompt = template["nprompt"] if "nprompt" in template else "{content}"
+
         image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
 
         return [
             CSVDataItem(
-                self.data_root.joinpath(item.image),
-                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
-                item.prompt,
-                item.nprompt
+                self.data_root.joinpath(image.format(item["image"])),
+                self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
+                prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
+                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
             )
             for item in data
            for i in range(image_multiplier)
         ]
 
     def prepare_data(self):
-        metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
-        num_images = len(metadata)
+        with open(self.data_file, 'rt') as f:
+            metadata = json.load(f)
+        template = metadata["template"] if "template" in metadata else {}
+        items = metadata["items"] if "items" in metadata else []
+
+        items = [item for item in items if not "skip" in item or item["skip"] != True]
+        num_images = len(items)
 
         valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
@@ -85,10 +97,10 @@ class CSVDataModule(pl.LightningDataModule):
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
-        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
 
-        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
-        self.data_val = self.prepare_subdata(data_val)
+        self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(template, data_val)
 
     def setup(self, stage=None):
         train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
@@ -133,8 +145,8 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
-        self.cache = {}
         self.image_cache = {}
+        self.input_id_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -168,12 +180,19 @@ class CSVDataset(Dataset):
 
         return image
 
+    def get_input_ids(self, prompt, identifier):
+        prompt = prompt.format(identifier)
+
+        if prompt in self.input_id_cache:
+            return self.input_id_cache[prompt]
+
+        input_ids = self.prompt_processor.get_input_ids(prompt)
+        self.input_id_cache[prompt] = input_ids
+
+        return input_ids
+
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
-        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
-
-        if cache_key in self.cache:
-            return self.cache[cache_key]
 
         example = {}
 
@@ -181,17 +200,12 @@ class CSVDataset(Dataset):
         example["nprompts"] = item.nprompt
 
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            item.prompt.format(self.instance_identifier)
-        )
+        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
-                item.nprompt.format(self.class_identifier)
-            )
+            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
-        self.cache[cache_key] = example
         return example
 
     def __getitem__(self, i):
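The new prepare_data path drops the pandas-based reader in favor of plain json.load and expects a top-level "template" object plus an "items" list, with an optional per-item "skip" flag. Below is a minimal sketch of a metadata file that this loader would accept, inferred from the lookups in prepare_subdata; the file name and the concrete prompt strings are illustrative, not taken from the repository.

```python
import json

# Hypothetical metadata for the new "template" + "items" layout read by
# CSVDataModule.prepare_data(); keys mirror the lookups in prepare_subdata().
metadata = {
    "template": {
        "image": "{}",                     # formatted with item["image"]
        "prompt": "a photo of {content}",  # {content} comes from prepare_prompt()
        "nprompt": "{content}",
    },
    "items": [
        {"image": "0001.jpg", "prompt": "wearing a red hat"},
        {"image": "0002.jpg", "prompt": {"content": "at the beach"}, "nprompt": "blurry"},
        {"image": "0003.jpg", "skip": True},  # dropped by the skip filter
    ],
}

with open("instance-data.json", "w") as f:
    json.dump(metadata, f, indent=2)
```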
diff --git a/dreambooth.py b/dreambooth.py
index 5c26f12..2c24908 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -71,13 +71,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -316,6 +316,18 @@ def parse_args():
     if args.instance_identifier is None:
         raise ValueError("You must specify --instance_identifier")
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
+
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -379,9 +391,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
-        if self.placeholder_token_id is None:
-            return
-
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -389,12 +398,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -467,7 +477,7 @@ class Checkpointer:
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
                 prompt = [
-                    prompt.format(self.instance_identifier)
+                    prompt.format(identifier=self.instance_identifier)
                     for batch in batches
                     for prompt in batch["prompts"]
                 ][:self.sample_batch_size]
@@ -516,8 +526,8 @@ def main():
 
     instance_identifier = args.instance_identifier
 
-    if args.placeholder_token is not None:
-        instance_identifier = instance_identifier.format(args.placeholder_token)
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
@@ -565,18 +575,16 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if args.initializer_token is not None:
+    if len(args.initializer_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
-        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+        initializer_token_ids = torch.stack([
+            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+            for token in args.initializer_token
+        ])
 
-        # Add the placeholder token in tokenizer
         num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        if num_added_tokens == 0:
-            print(f"Re-using existing token {args.placeholder_token}.")
-        else:
-            print(f"Training new token {args.placeholder_token}.")
+        print(f"Added {num_added_tokens} new tokens.")
+
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
         # Resize the token embeddings as we are adding new special tokens to the tokenizer
@@ -586,7 +594,9 @@ def main():
         token_embeds = text_encoder.get_input_embeddings().weight.data
         original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
@@ -594,7 +604,7 @@ def main():
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
     else:
-        placeholder_token_id = None
+        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -721,7 +731,7 @@ def main():
             with torch.inference_mode():
                 for batch in batched_data:
                     image_name = [item.class_image_path for item in batch]
-                    prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
                     nprompt = [item.nprompt for item in batch]
 
                     images = pipeline(
@@ -787,7 +797,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -932,6 +945,9 @@ def main():
                 global_step += 1
 
                 if global_step % args.sample_frequency == 0:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     checkpointer.save_embedding(global_step, "training")
                     sample_checkpoint = True
 
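With save_embedding now looping over every (placeholder_token, placeholder_token_id) pair, each trained token is written to its own .bin file holding a one-entry dict of token name to embedding tensor. A rough sketch of loading such a file back into a tokenizer/text-encoder pair follows; the model path and checkpoint file name are placeholders, not values from this commit.

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Hypothetical paths: any Stable Diffusion checkpoint layout with tokenizer/
# text_encoder subfolders, and one of the files written by save_embedding().
model_path = "path/to/stable-diffusion-model"
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")

learned_embeds = torch.load("checkpoints/token-1_500_training.bin", map_location="cpu")

for token, embedding in learned_embeds.items():
    # Each file maps a single placeholder token to its learned embedding vector.
    tokenizer.add_tokens(token)
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    text_encoder.get_input_embeddings().weight.data[token_id] = embedding
```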
diff --git a/textual_inversion.py b/textual_inversion.py
index c42762f..bcdfd3a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -70,13 +70,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -299,12 +299,21 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.placeholder_token is None:
-        raise ValueError("You must specify --placeholder_token")
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
 
-    if args.initializer_token is None:
+    if len(args.initializer_token) == 0:
         raise ValueError("You must specify --initializer_token")
 
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("You must specify --placeholder_token")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -373,12 +382,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -422,7 +432,7 @@ class Checkpointer:
 
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
+                prompt = [prompt.format(identifier=self.instance_identifier)
                           for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
@@ -498,16 +508,13 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
-    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
 
-    # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    if num_added_tokens == 0:
-        print(f"Re-using existing token {args.placeholder_token}.")
-    else:
-        print(f"Training new token {args.placeholder_token}.")
+    print(f"Added {num_added_tokens} new tokens.")
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
@@ -533,11 +540,11 @@ def main():
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
-        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
-            args.placeholder_token]
+        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
     else:
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
     # Freeze vae and unet
     freeze_params(vae.parameters())
@@ -648,7 +655,7 @@ def main():
             with torch.inference_mode():
                 for batch in batched_data:
                     image_name = [p.class_image_path for p in batch]
-                    prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                    prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
                     nprompt = [p.nprompt for p in batch]
 
                     images = pipeline(
@@ -716,7 +723,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("textual_inversion", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
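Both scripts now pass the identifier as a keyword argument to str.format, so prompts that reach these call sites are expected to carry a named {identifier} field rather than an anonymous {} slot. A tiny illustration (the prompt string and token names below are made up):

```python
# Hypothetical prompt with the named placeholder filled by the new
# .format(identifier=...) calls in dreambooth.py and textual_inversion.py.
prompt = "a photo of {identifier} at the beach"

print(prompt.format(identifier="person"))     # class prompt for prior-preservation images
print(prompt.format(identifier="<token-1>"))  # instance prompt during sampling
```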