From 1bd386f98bb076fe62696808e02a9bd9b9b64b42 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 23 Dec 2022 21:47:12 +0100 Subject: Improved class prompt handling --- train_ti.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 9d06c50..55daa35 100644 --- a/train_ti.py +++ b/train_ti.py @@ -63,16 +63,10 @@ def parse_args(): default="template", ) parser.add_argument( - "--instance_identifier", + "--project", type=str, default=None, - help="A token to use as a placeholder for the concept.", - ) - parser.add_argument( - "--class_identifier", - type=str, - default=None, - help="A token to use as a placeholder for the concept.", + help="The name of the current project.", ) parser.add_argument( "--placeholder_token", @@ -334,6 +328,9 @@ def parse_args(): if args.pretrained_model_name_or_path is None: raise ValueError("You must specify --pretrained_model_name_or_path") + if args.project is None: + raise ValueError("You must specify --project") + if isinstance(args.initializer_token, str): args.initializer_token = [args.initializer_token] @@ -366,7 +363,6 @@ class Checkpointer(CheckpointerBase): text_encoder, scheduler, text_embeddings, - instance_identifier, placeholder_token, placeholder_token_id, output_dir: Path, @@ -378,7 +374,6 @@ class Checkpointer(CheckpointerBase): super().__init__( datamodule=datamodule, output_dir=output_dir, - instance_identifier=instance_identifier, placeholder_token=placeholder_token, placeholder_token_id=placeholder_token_id, sample_image_size=sample_image_size, @@ -441,14 +436,9 @@ class Checkpointer(CheckpointerBase): def main(): args = parse_args() - instance_identifier = args.instance_identifier - - if len(args.placeholder_token) != 0: - instance_identifier = instance_identifier.format(args.placeholder_token[0]) - global_step_offset = args.global_step now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) + basepath = Path(args.output_dir).joinpath(slugify(args.project), now) basepath.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -567,6 +557,7 @@ def main(): def collate_fn(examples): prompts = [example["prompts"] for example in examples] + cprompts = [example["cprompts"] for example in examples] nprompts = [example["nprompts"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -583,6 +574,7 @@ def main(): batch = { "prompts": prompts, + "cprompts": cprompts, "nprompts": nprompts, "input_ids": inputs.input_ids, "pixel_values": pixel_values, @@ -594,8 +586,6 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, - instance_identifier=args.instance_identifier, - class_identifier=args.class_identifier, class_subdir="cls", num_class_images=args.num_class_images, size=args.resolution, @@ -634,7 +624,7 @@ def main(): with torch.autocast("cuda"), torch.inference_mode(): for batch in batched_data: image_name = [item.class_image_path for item in batch] - prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] + prompt = [item.cprompt for item in batch] nprompt = [item.nprompt for item in batch] images = pipeline( @@ -744,7 +734,6 @@ def main(): text_encoder=text_encoder, scheduler=checkpoint_scheduler, text_embeddings=text_embeddings, - instance_identifier=args.instance_identifier, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, output_dir=basepath, -- cgit v1.2.3-54-g00ecf