diff options
author | Volpeon <git@volpeon.ink> | 2023-01-08 13:38:43 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-08 13:38:43 +0100 |
commit | 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 (patch) | |
tree | d275e13506ca737efef18dc6dffa05f4e0d6759f /train_ti.py | |
parent | Improved aspect ratio bucketing (diff) | |
download | textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.gz textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.bz2 textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.zip |
Fixed aspect ratio bucketing; allow passing token IDs to pipeline
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 24 |
1 file changed, 13 insertions, 11 deletions
diff --git a/train_ti.py b/train_ti.py index 727b591..323ef10 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -140,13 +140,13 @@ def parse_args(): | |||
140 | ), | 140 | ), |
141 | ) | 141 | ) |
142 | parser.add_argument( | 142 | parser.add_argument( |
143 | "--num_aspect_ratio_buckets", | 143 | "--num_buckets", |
144 | type=int, | 144 | type=int, |
145 | default=4, | 145 | default=4, |
146 | help="Number of buckets in either direction (adds 64 pixels per step).", | 146 | help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).", |
147 | ) | 147 | ) |
148 | parser.add_argument( | 148 | parser.add_argument( |
149 | "--progressive_aspect_ratio_buckets", | 149 | "--progressive_buckets", |
150 | action="store_true", | 150 | action="store_true", |
151 | help="Include images in smaller buckets as well.", | 151 | help="Include images in smaller buckets as well.", |
152 | ) | 152 | ) |
@@ -681,9 +681,9 @@ def main(): | |||
681 | return cond1 and cond3 and cond4 | 681 | return cond1 and cond3 and cond4 |
682 | 682 | ||
683 | def collate_fn(examples): | 683 | def collate_fn(examples): |
684 | prompts = [example["prompts"] for example in examples] | 684 | prompt_ids = [example["prompt_ids"] for example in examples] |
685 | cprompts = [example["cprompts"] for example in examples] | 685 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
686 | nprompts = [example["nprompts"] for example in examples] | 686 | |
687 | input_ids = [example["instance_prompt_ids"] for example in examples] | 687 | input_ids = [example["instance_prompt_ids"] for example in examples] |
688 | pixel_values = [example["instance_images"] for example in examples] | 688 | pixel_values = [example["instance_images"] for example in examples] |
689 | 689 | ||
@@ -695,16 +695,18 @@ def main(): | |||
695 | pixel_values = torch.stack(pixel_values) | 695 | pixel_values = torch.stack(pixel_values) |
696 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 696 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
697 | 697 | ||
698 | prompts = prompt_processor.unify_input_ids(prompt_ids) | ||
699 | nprompts = prompt_processor.unify_input_ids(nprompt_ids) | ||
698 | inputs = prompt_processor.unify_input_ids(input_ids) | 700 | inputs = prompt_processor.unify_input_ids(input_ids) |
699 | 701 | ||
700 | batch = { | 702 | batch = { |
701 | "prompts": prompts, | 703 | "prompt_ids": prompts.input_ids, |
702 | "cprompts": cprompts, | 704 | "nprompt_ids": nprompts.input_ids, |
703 | "nprompts": nprompts, | ||
704 | "input_ids": inputs.input_ids, | 705 | "input_ids": inputs.input_ids, |
705 | "pixel_values": pixel_values, | 706 | "pixel_values": pixel_values, |
706 | "attention_mask": inputs.attention_mask, | 707 | "attention_mask": inputs.attention_mask, |
707 | } | 708 | } |
709 | |||
708 | return batch | 710 | return batch |
709 | 711 | ||
710 | datamodule = VlpnDataModule( | 712 | datamodule = VlpnDataModule( |
@@ -714,8 +716,8 @@ def main(): | |||
714 | class_subdir=args.class_image_dir, | 716 | class_subdir=args.class_image_dir, |
715 | num_class_images=args.num_class_images, | 717 | num_class_images=args.num_class_images, |
716 | size=args.resolution, | 718 | size=args.resolution, |
717 | num_aspect_ratio_buckets=args.num_aspect_ratio_buckets, | 719 | num_buckets=args.num_buckets, |
718 | progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets, | 720 | progressive_buckets=args.progressive_buckets, |
719 | dropout=args.tag_dropout, | 721 | dropout=args.tag_dropout, |
720 | template_key=args.train_data_template, | 722 | template_key=args.train_data_template, |
721 | valid_set_size=args.valid_set_size, | 723 | valid_set_size=args.valid_set_size, |