From 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 13:38:43 +0100 Subject: Fixed aspect ratio bucketing; allow passing token IDs to pipeline --- train_ti.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 727b591..323ef10 100644 --- a/train_ti.py +++ b/train_ti.py @@ -140,13 +140,13 @@ def parse_args(): ), ) parser.add_argument( - "--num_aspect_ratio_buckets", + "--num_buckets", type=int, default=4, - help="Number of buckets in either direction (adds 64 pixels per step).", + help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).", ) parser.add_argument( - "--progressive_aspect_ratio_buckets", + "--progressive_buckets", action="store_true", help="Include images in smaller buckets as well.", ) @@ -681,9 +681,9 @@ def main(): return cond1 and cond3 and cond4 def collate_fn(examples): - prompts = [example["prompts"] for example in examples] - cprompts = [example["cprompts"] for example in examples] - nprompts = [example["nprompts"] for example in examples] + prompt_ids = [example["prompt_ids"] for example in examples] + nprompt_ids = [example["nprompt_ids"] for example in examples] + input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -695,16 +695,18 @@ def main(): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) + prompts = prompt_processor.unify_input_ids(prompt_ids) + nprompts = prompt_processor.unify_input_ids(nprompt_ids) inputs = prompt_processor.unify_input_ids(input_ids) batch = { - "prompts": prompts, - "cprompts": cprompts, - "nprompts": nprompts, + "prompt_ids": prompts.input_ids, + "nprompt_ids": nprompts.input_ids, "input_ids": inputs.input_ids, "pixel_values": pixel_values, "attention_mask": inputs.attention_mask, } + return batch datamodule = VlpnDataModule( @@ -714,8 +716,8 @@ def main(): class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, - num_aspect_ratio_buckets=args.num_aspect_ratio_buckets, - progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets, + num_buckets=args.num_buckets, + progressive_buckets=args.progressive_buckets, dropout=args.tag_dropout, template_key=args.train_data_template, valid_set_size=args.valid_set_size, -- cgit v1.2.3-54-g00ecf