From 1bd386f98bb076fe62696808e02a9bd9b9b64b42 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 21:47:12 +0100 Subject: Improved class prompt handling --- train_dreambooth.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 9749c62..ff67d12 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -62,16 +62,10 @@ def parse_args(): default="template", ) parser.add_argument( - "--instance_identifier", + "--project", type=str, default=None, - help="A token to use as a placeholder for the concept.", - ) - parser.add_argument( - "--class_identifier", - type=str, - default=None, - help="A token to use as a placeholder for the concept.", + help="The name of the current project.", ) parser.add_argument( "--placeholder_token", @@ -364,8 +358,8 @@ def parse_args(): if args.pretrained_model_name_or_path is None: raise ValueError("You must specify --pretrained_model_name_or_path") - if args.instance_identifier is None: - raise ValueError("You must specify --instance_identifier") + if args.project is None: + raise ValueError("You must specify --project") if isinstance(args.initializer_token, str): args.initializer_token = [args.initializer_token] @@ -397,7 +391,6 @@ class Checkpointer(CheckpointerBase): text_encoder, scheduler, output_dir: Path, - instance_identifier, placeholder_token, placeholder_token_id, sample_image_size, @@ -408,7 +401,6 @@ class Checkpointer(CheckpointerBase): super().__init__( datamodule=datamodule, output_dir=output_dir, - instance_identifier=instance_identifier, placeholder_token=placeholder_token, placeholder_token_id=placeholder_token_id, sample_image_size=sample_image_size, @@ -481,13 +473,8 @@ def main(): "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." ) - instance_identifier = args.instance_identifier - - if len(args.placeholder_token) != 0: - instance_identifier = instance_identifier.format(args.placeholder_token[0]) - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) + basepath = Path(args.output_dir).joinpath(slugify(args.project), now) basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -637,6 +624,7 @@ def main(): def collate_fn(examples): prompts = [example["prompts"] for example in examples] + cprompts = [example["cprompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -653,6 +641,7 @@ def main(): batch = { "prompts": prompts, + "cprompts": cprompts, "nprompts": nprompts, "input_ids": inputs.input_ids, "pixel_values": pixel_values, @@ -664,8 +653,6 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, - instance_identifier=instance_identifier, - class_identifier=args.class_identifier, class_subdir="cls", num_class_images=args.num_class_images, size=args.resolution, @@ -703,7 +690,7 @@ def main(): with torch.autocast("cuda"), torch.inference_mode(): for batch in batched_data: image_name = [item.class_image_path for item in batch] - prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] + prompt = [item.cprompt for item in batch] nprompt = [item.nprompt for item in batch] images = pipeline( @@ -812,7 +799,6 @@ def main(): text_encoder=text_encoder, scheduler=checkpoint_scheduler, output_dir=basepath, - instance_identifier=instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, sample_image_size=args.sample_image_size, -- cgit v1.2.3-54-g00ecf