From baba91864a45939cef4f77f6ca96ade7ae5ef274 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 24 Oct 2022 23:46:18 +0200
Subject: Advanced datasets

---
 data/csv.py          | 64 ++++++++++++++++++++++++++++++-------------------
 dreambooth.py        | 68 ++++++++++++++++++++++++++++++++--------------------
 textual_inversion.py | 58 +++++++++++++++++++++++++-------------------
 3 files changed, 115 insertions(+), 75 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 5144c0a..f9b5e39 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,16 +1,20 @@
 import math
-import pandas as pd
 import torch
+import json
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import NamedTuple, List, Optional
+from typing import Dict, NamedTuple, List, Optional, Union
 
 from models.clip.prompt import PromptProcessor
 
 
+def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+    return {"content": prompt} if isinstance(prompt, str) else prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -60,24 +64,32 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
-    def prepare_subdata(self, data, num_class_images=1):
+    def prepare_subdata(self, template, data, num_class_images=1):
+        image = template["image"] if "image" in template else "{}"
+        prompt = template["prompt"] if "prompt" in template else "{content}"
+        nprompt = template["nprompt"] if "nprompt" in template else "{content}"
+
         image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
 
         return [
             CSVDataItem(
-                self.data_root.joinpath(item.image),
-                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
-                item.prompt,
-                item.nprompt
+                self.data_root.joinpath(image.format(item["image"])),
+                self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
+                prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
+                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
             )
             for item in data
             for i in range(image_multiplier)
         ]
 
     def prepare_data(self):
-        metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
-        num_images = len(metadata)
+        with open(self.data_file, 'rt') as f:
+            metadata = json.load(f)
+            template = metadata["template"] if "template" in metadata else {}
+            items = metadata["items"] if "items" in metadata else []
+
+        items = [item for item in items if not "skip" in item or item["skip"] != True]
+        num_images = len(items)
 
         valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
@@ -85,10 +97,10 @@ class CSVDataModule(pl.LightningDataModule):
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
-        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
 
-        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
-        self.data_val = self.prepare_subdata(data_val)
+        self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(template, data_val)
 
     def setup(self, stage=None):
         train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
@@ -133,8 +145,8 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
-        self.cache = {}
         self.image_cache = {}
+        self.input_id_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -168,12 +180,19 @@ class CSVDataset(Dataset):
 
         return image
 
+    def get_input_ids(self, prompt, identifier):
+        prompt = prompt.format(identifier=identifier)
+
+        if prompt in self.input_id_cache:
+            return self.input_id_cache[prompt]
+
+        input_ids = self.prompt_processor.get_input_ids(prompt)
+        self.input_id_cache[prompt] = input_ids
+
+        return input_ids
+
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
-        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
-
-        if cache_key in self.cache:
-            return self.cache[cache_key]
 
         example = {}
 
@@ -181,17 +200,12 @@ class CSVDataset(Dataset):
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            item.prompt.format(self.instance_identifier)
-        )
+        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
-                item.nprompt.format(self.class_identifier)
-            )
+            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
-        self.cache[cache_key] = example
         return example
 
     def __getitem__(self, i):
diff --git a/dreambooth.py b/dreambooth.py
index 5c26f12..2c24908 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -71,13 +71,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*', default=[],
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*', default=[],
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -316,6 +316,18 @@ def parse_args():
     if args.instance_identifier is None:
         raise ValueError("You must specify --instance_identifier")
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
+
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -379,9 +391,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
-        if self.placeholder_token_id is None:
-            return
-
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -389,12 +398,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -467,7 +477,7 @@ class Checkpointer:
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
             prompt = [
-                prompt.format(self.instance_identifier)
+                prompt.format(identifier=self.instance_identifier)
                 for batch in batches
                 for prompt in batch["prompts"]
             ][:self.sample_batch_size]
@@ -516,8 +526,8 @@ def main():
 
     instance_identifier = args.instance_identifier
 
-    if args.placeholder_token is not None:
-        instance_identifier = instance_identifier.format(args.placeholder_token)
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
@@ -565,18 +575,16 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if args.initializer_token is not None:
+    if len(args.initializer_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
-        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+        initializer_token_ids = torch.stack([
+            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+            for token in args.initializer_token
+        ])
 
-        # Add the placeholder token in tokenizer
         num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        if num_added_tokens == 0:
-            print(f"Re-using existing token {args.placeholder_token}.")
-        else:
-            print(f"Training new token {args.placeholder_token}.")
+        print(f"Added {num_added_tokens} new tokens.")
+
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
         # Resize the token embeddings as we are adding new special tokens to the tokenizer
@@ -586,7 +594,9 @@ def main():
         token_embeds = text_encoder.get_input_embeddings().weight.data
         original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
@@ -594,7 +604,7 @@ def main():
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
     else:
-        placeholder_token_id = None
+        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -721,7 +731,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
@@ -787,7 +797,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -932,6 +945,9 @@ def main():
                     global_step += 1
 
                     if global_step % args.sample_frequency == 0:
+                        local_progress_bar.clear()
+                        global_progress_bar.clear()
+
                         checkpointer.save_embedding(global_step, "training")
                         sample_checkpoint = True
 
diff --git a/textual_inversion.py b/textual_inversion.py
index c42762f..bcdfd3a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -70,13 +70,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*', default=[],
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*', default=[],
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -299,12 +299,21 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.placeholder_token is None:
-        raise ValueError("You must specify --placeholder_token")
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
 
-    if args.initializer_token is None:
+    if len(args.initializer_token) == 0:
         raise ValueError("You must specify --initializer_token")
 
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -373,12 +382,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -422,7 +432,7 @@ class Checkpointer:
 
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.instance_identifier)
+            prompt = [prompt.format(identifier=self.instance_identifier)
                       for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
@@ -498,16 +508,13 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
-    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
 
-    # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    if num_added_tokens == 0:
-        print(f"Re-using existing token {args.placeholder_token}.")
-    else:
-        print(f"Training new token {args.placeholder_token}.")
+    print(f"Added {num_added_tokens} new tokens.")
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
@@ -533,11 +540,11 @@ def main():
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
-        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
-            args.placeholder_token]
+        token_embeds[placeholder_token_id] = torch.stack([torch.load(args.resume_checkpoint)[token] for token in args.placeholder_token])
     else:
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
     # Freeze vae and unet
     freeze_params(vae.parameters())
@@ -648,7 +655,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -716,7 +723,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("textual_inversion", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-- 
cgit v1.2.3-54-g00ecf
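
Note on the new dataset format: data/csv.py now reads a JSON metadata file
instead of the old pandas-loaded table. A sketch of a file the new loader
would accept follows; the keys come from the code above, but every file name
and prompt string here is an invented example:

    {
      "template": {
        "image": "images/{}",
        "prompt": "{content}, high quality",
        "nprompt": "{content}"
      },
      "items": [
        {"image": "01.jpg", "prompt": "a photo of {identifier} on a sofa", "nprompt": "blurry"},
        {"image": "02.jpg"},
        {"image": "03.jpg", "skip": true}
      ]
    }

Both "template" and "items" are optional; a missing item prompt defaults to an
empty string, and items with "skip": true are dropped before the train/valid
split. An item's "prompt" may also be a JSON object, in which case its keys
(rather than the single "content" key that prepare_prompt wraps around a plain
string) are substituted into the template. The "{identifier}" placeholder is
left untouched at load time and filled in later by get_input_ids and the
sampling code with the instance or class identifier.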
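Note on the multi-token command line: --placeholder_token and --initializer_token
now take any number of space-separated values (argparse nargs='*') and must have
matching lengths; if --placeholder_token is omitted, placeholders named <*0>,
<*1>, ... are generated, one per initializer token. An illustrative invocation,
with made-up model path and tokens and other required arguments omitted:

    python textual_inversion.py \
      --pretrained_model_name_or_path models/stable-diffusion \
      --placeholder_token "<obj>" "<style>" \
      --initializer_token toy painting \
      --output_dir output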