summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py24
1 file changed, 13 insertions, 11 deletions
diff --git a/train_ti.py b/train_ti.py
index 727b591..323ef10 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -140,13 +140,13 @@ def parse_args():
140 ), 140 ),
141 ) 141 )
142 parser.add_argument( 142 parser.add_argument(
143 "--num_aspect_ratio_buckets", 143 "--num_buckets",
144 type=int, 144 type=int,
145 default=4, 145 default=4,
146 help="Number of buckets in either direction (adds 64 pixels per step).", 146 help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).",
147 ) 147 )
148 parser.add_argument( 148 parser.add_argument(
149 "--progressive_aspect_ratio_buckets", 149 "--progressive_buckets",
150 action="store_true", 150 action="store_true",
151 help="Include images in smaller buckets as well.", 151 help="Include images in smaller buckets as well.",
152 ) 152 )
@@ -681,9 +681,9 @@ def main():
681 return cond1 and cond3 and cond4 681 return cond1 and cond3 and cond4
682 682
683 def collate_fn(examples): 683 def collate_fn(examples):
684 prompts = [example["prompts"] for example in examples] 684 prompt_ids = [example["prompt_ids"] for example in examples]
685 cprompts = [example["cprompts"] for example in examples] 685 nprompt_ids = [example["nprompt_ids"] for example in examples]
686 nprompts = [example["nprompts"] for example in examples] 686
687 input_ids = [example["instance_prompt_ids"] for example in examples] 687 input_ids = [example["instance_prompt_ids"] for example in examples]
688 pixel_values = [example["instance_images"] for example in examples] 688 pixel_values = [example["instance_images"] for example in examples]
689 689
@@ -695,16 +695,18 @@ def main():
695 pixel_values = torch.stack(pixel_values) 695 pixel_values = torch.stack(pixel_values)
696 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) 696 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
697 697
698 prompts = prompt_processor.unify_input_ids(prompt_ids)
699 nprompts = prompt_processor.unify_input_ids(nprompt_ids)
698 inputs = prompt_processor.unify_input_ids(input_ids) 700 inputs = prompt_processor.unify_input_ids(input_ids)
699 701
700 batch = { 702 batch = {
701 "prompts": prompts, 703 "prompt_ids": prompts.input_ids,
702 "cprompts": cprompts, 704 "nprompt_ids": nprompts.input_ids,
703 "nprompts": nprompts,
704 "input_ids": inputs.input_ids, 705 "input_ids": inputs.input_ids,
705 "pixel_values": pixel_values, 706 "pixel_values": pixel_values,
706 "attention_mask": inputs.attention_mask, 707 "attention_mask": inputs.attention_mask,
707 } 708 }
709
708 return batch 710 return batch
709 711
710 datamodule = VlpnDataModule( 712 datamodule = VlpnDataModule(
@@ -714,8 +716,8 @@ def main():
714 class_subdir=args.class_image_dir, 716 class_subdir=args.class_image_dir,
715 num_class_images=args.num_class_images, 717 num_class_images=args.num_class_images,
716 size=args.resolution, 718 size=args.resolution,
717 num_aspect_ratio_buckets=args.num_aspect_ratio_buckets, 719 num_buckets=args.num_buckets,
718 progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets, 720 progressive_buckets=args.progressive_buckets,
719 dropout=args.tag_dropout, 721 dropout=args.tag_dropout,
720 template_key=args.train_data_template, 722 template_key=args.train_data_template,
721 valid_set_size=args.valid_set_size, 723 valid_set_size=args.valid_set_size,