 data/csv.py          | 64
 dreambooth.py        | 68
 textual_inversion.py | 58
 3 files changed, 115 insertions(+), 75 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 5144c0a..f9b5e39 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,16 +1,20 @@
 import math
-import pandas as pd
 import torch
+import json
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import NamedTuple, List, Optional
+from typing import Dict, NamedTuple, List, Optional, Union
 
 from models.clip.prompt import PromptProcessor
 
 
+def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+    return {"content": prompt} if isinstance(prompt, str) else prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -60,24 +64,32 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
-    def prepare_subdata(self, data, num_class_images=1):
+    def prepare_subdata(self, template, data, num_class_images=1):
+        image = template["image"] if "image" in template else "{}"
+        prompt = template["prompt"] if "prompt" in template else "{content}"
+        nprompt = template["nprompt"] if "nprompt" in template else "{content}"
+
         image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
 
         return [
             CSVDataItem(
-                self.data_root.joinpath(item.image),
-                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
-                item.prompt,
-                item.nprompt
+                self.data_root.joinpath(image.format(item["image"])),
+                self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
+                prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
+                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
             )
             for item in data
             for i in range(image_multiplier)
         ]
 
     def prepare_data(self):
-        metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
-        num_images = len(metadata)
+        with open(self.data_file, 'rt') as f:
+            metadata = json.load(f)
+        template = metadata["template"] if "template" in metadata else {}
+        items = metadata["items"] if "items" in metadata else []
+
+        items = [item for item in items if not "skip" in item or item["skip"] != True]
+        num_images = len(items)
 
         valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
@@ -85,10 +97,10 @@ class CSVDataModule(pl.LightningDataModule):
             valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
-        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
 
-        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
-        self.data_val = self.prepare_subdata(data_val)
+        self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(template, data_val)
 
     def setup(self, stage=None):
         train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
@@ -133,8 +145,8 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
-        self.cache = {}
         self.image_cache = {}
+        self.input_id_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -168,12 +180,19 @@ class CSVDataset(Dataset):
 
         return image
 
+    def get_input_ids(self, prompt, identifier):
+        prompt = prompt.format(identifier)
+
+        if prompt in self.input_id_cache:
+            return self.input_id_cache[prompt]
+
+        input_ids = self.prompt_processor.get_input_ids(prompt)
+        self.input_id_cache[prompt] = input_ids
+
+        return input_ids
+
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
-        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
-
-        if cache_key in self.cache:
-            return self.cache[cache_key]
 
         example = {}
 
@@ -181,17 +200,12 @@ class CSVDataset(Dataset):
         example["nprompts"] = item.nprompt
 
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            item.prompt.format(self.instance_identifier)
-        )
+        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
-                item.nprompt.format(self.class_identifier)
-            )
+            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
-        self.cache[cache_key] = example
         return example
 
     def __getitem__(self, i):
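
For reference, prepare_data now reads a plain JSON file instead of pandas-parsed metadata: a top-level "template" object with optional "image"/"prompt"/"nprompt" format strings, and an "items" list. A minimal sketch of that layout and of how the templates resolve; the file names and prompt text here are illustrative, not taken from the repo:

# Sketch of the metadata layout the new json loader accepts; keys mirror
# the code above, the concrete values are made up for illustration.
metadata = {
    "template": {
        "image": "{}",                      # optional, defaults to "{}"
        "prompt": "a photo of {content}",   # optional, defaults to "{content}"
        "nprompt": "{content}",             # optional, defaults to "{content}"
    },
    "items": [
        {"image": "img/0001.jpg", "prompt": "portrait"},
        {"image": "img/0002.jpg", "skip": True},  # filtered out by prepare_data
    ],
}

def prepare_prompt(prompt):
    # same helper as above: bare strings are wrapped so that
    # template.format(**...) works for plain and structured prompts alike
    return {"content": prompt} if isinstance(prompt, str) else prompt

template = metadata.get("template", {})
prompt_template = template.get("prompt", "{content}")

for item in metadata["items"]:
    if item.get("skip"):
        continue
    print(prompt_template.format(**prepare_prompt(item.get("prompt", ""))))
    # -> a photo of portrait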
diff --git a/dreambooth.py b/dreambooth.py
index 5c26f12..2c24908 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -71,13 +71,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -316,6 +316,18 @@ def parse_args():
     if args.instance_identifier is None:
         raise ValueError("You must specify --instance_identifier")
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
+
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -379,9 +391,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
-        if self.placeholder_token_id is None:
-            return
-
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -389,12 +398,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -467,7 +477,7 @@ class Checkpointer:
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
             prompt = [
-                prompt.format(self.instance_identifier)
+                prompt.format(identifier=self.instance_identifier)
                 for batch in batches
                 for prompt in batch["prompts"]
            ][:self.sample_batch_size]
@@ -516,8 +526,8 @@ def main():
 
     instance_identifier = args.instance_identifier
 
-    if args.placeholder_token is not None:
-        instance_identifier = instance_identifier.format(args.placeholder_token)
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
@@ -565,18 +575,16 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if args.initializer_token is not None:
+    if len(args.initializer_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
-        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+        initializer_token_ids = torch.stack([
+            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+            for token in args.initializer_token
+        ])
 
-        # Add the placeholder token in tokenizer
         num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        if num_added_tokens == 0:
-            print(f"Re-using existing token {args.placeholder_token}.")
-        else:
-            print(f"Training new token {args.placeholder_token}.")
+        print(f"Added {num_added_tokens} new tokens.")
+
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
         # Resize the token embeddings as we are adding new special tokens to the tokenizer
@@ -586,7 +594,9 @@ def main():
         token_embeds = text_encoder.get_input_embeddings().weight.data
         original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
@@ -594,7 +604,7 @@ def main():
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
     else:
-        placeholder_token_id = None
+        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -721,7 +731,7 @@ def main():
             with torch.inference_mode():
                 for batch in batched_data:
                     image_name = [item.class_image_path for item in batch]
-                    prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
                     nprompt = [item.nprompt for item in batch]
 
                     images = pipeline(
@@ -787,7 +797,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -932,6 +945,9 @@ def main():
                     global_step += 1
 
                     if global_step % args.sample_frequency == 0:
+                        local_progress_bar.clear()
+                        global_progress_bar.clear()
+
                         checkpointer.save_embedding(global_step, "training")
                         sample_checkpoint = True
 
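
The embedding initialization above now handles one placeholder token per initializer word, copying only the first sub-token's embedding for each pair. A toy sketch of that loop, with a stand-in vocabulary and embedding table instead of the real CLIP tokenizer and text encoder:

import torch

placeholder_token = ["<dog>", "<cat>"]
initializer_token = ["dog", "cat"]

vocab = {"dog": 3, "cat": 5}             # stand-in tokenizer vocabulary
embeddings = torch.nn.Embedding(10, 4)   # stand-in text-encoder embeddings

# one row per initializer word, keeping only its first sub-token id,
# mirroring the torch.stack([...]) in the hunk above
initializer_token_ids = torch.stack([
    torch.tensor([vocab[token]]) for token in initializer_token
])

placeholder_token_id = [8, 9]            # pretend these ids were newly added

with torch.no_grad():
    token_embeds = embeddings.weight.data
    initializer_token_embeddings = embeddings(initializer_token_ids)  # (2, 1, 4)
    for (token_id, embeds) in zip(placeholder_token_id, initializer_token_embeddings):
        token_embeds[token_id] = embeds  # (1, 4) broadcasts into the (4,) row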
diff --git a/textual_inversion.py b/textual_inversion.py
index c42762f..bcdfd3a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -70,13 +70,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -299,12 +299,21 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.placeholder_token is None:
-        raise ValueError("You must specify --placeholder_token")
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
 
-    if args.initializer_token is None:
+    if len(args.initializer_token) == 0:
         raise ValueError("You must specify --initializer_token")
 
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -373,12 +382,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -422,7 +432,7 @@ class Checkpointer:
 
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.instance_identifier)
+            prompt = [prompt.format(identifier=self.instance_identifier)
                       for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
@@ -498,16 +508,13 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
-    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
 
-    # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    if num_added_tokens == 0:
-        print(f"Re-using existing token {args.placeholder_token}.")
-    else:
-        print(f"Training new token {args.placeholder_token}.")
+    print(f"Added {num_added_tokens} new tokens.")
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
@@ -533,11 +540,11 @@ def main():
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
-        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
-            args.placeholder_token]
+        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
     else:
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
     # Freeze vae and unet
     freeze_params(vae.parameters())
@@ -648,7 +655,7 @@ def main():
             with torch.inference_mode():
                 for batch in batched_data:
                     image_name = [p.class_image_path for p in batch]
-                    prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                    prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
                     nprompt = [p.nprompt for p in batch]
 
                     images = pipeline(
@@ -716,7 +723,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("textual_inversion", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
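
Both trainers now write one .bin file per placeholder token from save_embedding, each holding a single {token: embedding} entry. A rough sketch of the save/load round trip this implies; the paths, step number, and 4-dim embeddings are illustrative only, and the real filenames also pass the token through slugify:

import torch
from pathlib import Path

checkpoints_path = Path("checkpoints")
checkpoints_path.mkdir(exist_ok=True)

placeholder_token = ["<dog>", "<cat>"]
placeholder_token_id = [8, 9]
weight = torch.randn(10, 4)  # stand-in for the text-encoder embedding matrix

# save: one {token: embedding} dict per token, as in the loops above
for (token, token_id) in zip(placeholder_token, placeholder_token_id):
    learned_embeds_dict = {token: weight[token_id].detach().cpu()}
    torch.save(learned_embeds_dict, checkpoints_path.joinpath(f"{token}_100_training.bin"))

# load: restore each token's row from its own checkpoint file
for (token, token_id) in zip(placeholder_token, placeholder_token_id):
    learned_embeds_dict = torch.load(checkpoints_path.joinpath(f"{token}_100_training.bin"))
    weight[token_id] = learned_embeds_dict[token]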
