From 1bd386f98bb076fe62696808e02a9bd9b9b64b42 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 21:47:12 +0100 Subject: Improved class prompt handling --- data/csv.py | 29 ++++++++++++----------------- environment.yaml | 2 +- train_dreambooth.py | 30 ++++++++---------------------- train_ti.py | 29 +++++++++-------------------- training/util.py | 4 +--- 5 files changed, 31 insertions(+), 63 deletions(-) diff --git a/data/csv.py b/data/csv.py index d400757..e25dd3f 100644 --- a/data/csv.py +++ b/data/csv.py @@ -38,6 +38,7 @@ class CSVDataItem(NamedTuple): instance_image_path: Path class_image_path: Path prompt: list[str] + cprompt: str nprompt: str @@ -47,8 +48,6 @@ class CSVDataModule(): batch_size: int, data_file: str, prompt_processor: PromptProcessor, - instance_identifier: str, - class_identifier: Optional[str] = None, class_subdir: str = "cls", num_class_images: int = 100, size: int = 512, @@ -77,8 +76,6 @@ class CSVDataModule(): self.num_class_images = num_class_images self.prompt_processor = prompt_processor - self.instance_identifier = instance_identifier - self.class_identifier = class_identifier self.size = size self.repeats = repeats self.dropout = dropout @@ -96,14 +93,18 @@ class CSVDataModule(): def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: image = template["image"] if "image" in template else "{}" prompt = template["prompt"] if "prompt" in template else "{content}" + cprompt = template["cprompt"] if "cprompt" in template else "{content}" nprompt = template["nprompt"] if "nprompt" in template else "{content}" return [ CSVDataItem( self.data_root.joinpath(image.format(item["image"])), None, - prompt_to_keywords(prompt.format( - **prepare_prompt(item["prompt"] if "prompt" in item else "")), expansions), + prompt_to_keywords( + prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), + expansions + ), + cprompt.format(**prepare_prompt(item["cprompt"] if "cprompt" in item else "")), nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), ) for item in data @@ -123,6 +124,7 @@ class CSVDataModule(): item.instance_image_path, self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), item.prompt, + item.cprompt, item.nprompt, ) for item in items @@ -160,12 +162,10 @@ class CSVDataModule(): def setup(self, stage=None): train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, - instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, num_class_images=self.num_class_images, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, - instance_identifier=self.instance_identifier, size=self.size, interpolation=self.interpolation, center_crop=self.center_crop) self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, @@ -187,9 +187,7 @@ class CSVDataset(Dataset): self, data: List[CSVDataItem], prompt_processor: PromptProcessor, - instance_identifier: str, batch_size: int = 1, - class_identifier: Optional[str] = None, num_class_images: int = 0, size: int = 512, repeats: int = 1, @@ -201,8 +199,6 @@ class CSVDataset(Dataset): self.data = data self.prompt_processor = prompt_processor self.batch_size = batch_size - self.instance_identifier = instance_identifier - self.class_identifier = class_identifier self.num_class_images = num_class_images self.dropout = dropout self.image_cache = {} @@ -239,14 +235,12 @@ class CSVDataset(Dataset): return image - def get_input_ids(self, prompt, identifier): - return self.prompt_processor.get_input_ids(prompt.format(identifier)) - def get_example(self, i): item = self.data[i % self.num_instance_images] example = {} example["prompts"] = item.prompt + example["cprompts"] = item.cprompt example["nprompts"] = item.nprompt example["instance_images"] = self.get_image(item.instance_image_path) if self.num_class_images != 0: @@ -260,13 +254,14 @@ class CSVDataset(Dataset): example = {} example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"], self.dropout) + example["cprompts"] = unprocessed_example["cprompts"] example["nprompts"] = unprocessed_example["nprompts"] example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) - example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier) + example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(example["prompts"]) if self.num_class_images != 0: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) - example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier) + example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) return example diff --git a/environment.yaml b/environment.yaml index e598f72..179fa38 100644 --- a/environment.yaml +++ b/environment.yaml @@ -21,4 +21,4 @@ dependencies: - test-tube>=0.7.5 - transformers==4.25.1 - triton==2.0.0.dev20221202 - - xformers==0.0.16rc391 + - xformers==0.0.16rc396 diff --git a/train_dreambooth.py b/train_dreambooth.py index 9749c62..ff67d12 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -62,16 +62,10 @@ def parse_args(): default="template", ) parser.add_argument( - "--instance_identifier", + "--project", type=str, default=None, - help="A token to use as a placeholder for the concept.", - ) - parser.add_argument( - "--class_identifier", - type=str, - default=None, - help="A token to use as a placeholder for the concept.", + help="The name of the current project.", ) parser.add_argument( "--placeholder_token", @@ -364,8 +358,8 @@ def parse_args(): if args.pretrained_model_name_or_path is None: raise ValueError("You must specify --pretrained_model_name_or_path") - if args.instance_identifier is None: - raise ValueError("You must specify --instance_identifier") + if args.project is None: + raise ValueError("You must specify --project") if isinstance(args.initializer_token, str): args.initializer_token = [args.initializer_token] @@ -397,7 +391,6 @@ class Checkpointer(CheckpointerBase): text_encoder, scheduler, output_dir: Path, - instance_identifier, placeholder_token, placeholder_token_id, sample_image_size, @@ -408,7 +401,6 @@ class Checkpointer(CheckpointerBase): super().__init__( datamodule=datamodule, output_dir=output_dir, - instance_identifier=instance_identifier, placeholder_token=placeholder_token, placeholder_token_id=placeholder_token_id, sample_image_size=sample_image_size, @@ -481,13 +473,8 @@ def main(): "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." ) - instance_identifier = args.instance_identifier - - if len(args.placeholder_token) != 0: - instance_identifier = instance_identifier.format(args.placeholder_token[0]) - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) + basepath = Path(args.output_dir).joinpath(slugify(args.project), now) basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -637,6 +624,7 @@ def main(): def collate_fn(examples): prompts = [example["prompts"] for example in examples] + cprompts = [example["cprompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -653,6 +641,7 @@ def main(): batch = { "prompts": prompts, + "cprompts": cprompts, "nprompts": nprompts, "input_ids": inputs.input_ids, "pixel_values": pixel_values, @@ -664,8 +653,6 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, - instance_identifier=instance_identifier, - class_identifier=args.class_identifier, class_subdir="cls", num_class_images=args.num_class_images, size=args.resolution, @@ -703,7 +690,7 @@ def main(): with torch.autocast("cuda"), torch.inference_mode(): for batch in batched_data: image_name = [item.class_image_path for item in batch] - prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] + prompt = [item.cprompt for item in batch] nprompt = [item.nprompt for item in batch] images = pipeline( @@ -812,7 +799,6 @@ def main(): text_encoder=text_encoder, scheduler=checkpoint_scheduler, output_dir=basepath, - instance_identifier=instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, sample_image_size=args.sample_image_size, diff --git a/train_ti.py b/train_ti.py index 9d06c50..55daa35 100644 --- a/train_ti.py +++ b/train_ti.py @@ -63,16 +63,10 @@ def parse_args(): default="template", ) parser.add_argument( - "--instance_identifier", + "--project", type=str, default=None, - help="A token to use as a placeholder for the concept.", - ) - parser.add_argument( - "--class_identifier", - type=str, - default=None, - help="A token to use as a placeholder for the concept.", + help="The name of the current project.", ) parser.add_argument( "--placeholder_token", @@ -334,6 +328,9 @@ def parse_args(): if args.pretrained_model_name_or_path is None: raise ValueError("You must specify --pretrained_model_name_or_path") + if args.project is None: + raise ValueError("You must specify --project") + if isinstance(args.initializer_token, str): args.initializer_token = [args.initializer_token] @@ -366,7 +363,6 @@ class Checkpointer(CheckpointerBase): text_encoder, scheduler, text_embeddings, - instance_identifier, placeholder_token, placeholder_token_id, output_dir: Path, @@ -378,7 +374,6 @@ class Checkpointer(CheckpointerBase): super().__init__( datamodule=datamodule, output_dir=output_dir, - instance_identifier=instance_identifier, placeholder_token=placeholder_token, placeholder_token_id=placeholder_token_id, sample_image_size=sample_image_size, @@ -441,14 +436,9 @@ class Checkpointer(CheckpointerBase): def main(): args = parse_args() - instance_identifier = args.instance_identifier - - if len(args.placeholder_token) != 0: - instance_identifier = instance_identifier.format(args.placeholder_token[0]) - global_step_offset = args.global_step now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) + basepath = Path(args.output_dir).joinpath(slugify(args.project), now) basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -567,6 +557,7 @@ def main(): def collate_fn(examples): prompts = [example["prompts"] for example in examples] + cprompts = [example["cprompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -583,6 +574,7 @@ def main(): batch = { "prompts": prompts, + "cprompts": cprompts, "nprompts": nprompts, "input_ids": inputs.input_ids, "pixel_values": pixel_values, @@ -594,8 +586,6 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, - instance_identifier=args.instance_identifier, - class_identifier=args.class_identifier, class_subdir="cls", num_class_images=args.num_class_images, size=args.resolution, @@ -634,7 +624,7 @@ def main(): with torch.autocast("cuda"), torch.inference_mode(): for batch in batched_data: image_name = [item.class_image_path for item in batch] - prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] + prompt = [item.cprompt for item in batch] nprompt = [item.nprompt for item in batch] images = pipeline( @@ -744,7 +734,6 @@ def main(): text_encoder=text_encoder, scheduler=checkpoint_scheduler, text_embeddings=text_embeddings, - instance_identifier=args.instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, output_dir=basepath, diff --git a/training/util.py b/training/util.py index e8d22ae..000173d 100644 --- a/training/util.py +++ b/training/util.py @@ -44,7 +44,6 @@ class CheckpointerBase: self, datamodule, output_dir: Path, - instance_identifier, placeholder_token, placeholder_token_id, sample_image_size, @@ -54,7 +53,6 @@ class CheckpointerBase: ): self.datamodule = datamodule self.output_dir = output_dir - self.instance_identifier = instance_identifier self.placeholder_token = placeholder_token self.placeholder_token_id = placeholder_token_id self.sample_image_size = sample_image_size @@ -90,7 +88,7 @@ class CheckpointerBase: if j * data.batch_size < self.sample_batch_size * self.sample_batches ] prompts = [ - prompt.format(identifier=self.instance_identifier) + prompt for batch in batches for prompt in batch["prompts"] ] -- cgit v1.2.3-54-g00ecf