summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-08 13:38:43 +0100
committerVolpeon <git@volpeon.ink>2023-01-08 13:38:43 +0100
commit7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 (patch)
treed275e13506ca737efef18dc6dffa05f4e0d6759f /train_dreambooth.py
parentImproved aspect ratio bucketing (diff)
downloadtextual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.gz
textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.bz2
textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.zip
Fixed aspect ratio bucketing; allow passing token IDs to pipeline
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 42a7d0f..79eede6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -699,9 +699,9 @@ def main():
699 return cond3 and cond4 699 return cond3 and cond4
700 700
701 def collate_fn(examples): 701 def collate_fn(examples):
702 prompts = [example["prompts"] for example in examples] 702 prompt_ids = [example["prompt_ids"] for example in examples]
703 cprompts = [example["cprompts"] for example in examples] 703 nprompt_ids = [example["nprompt_ids"] for example in examples]
704 nprompts = [example["nprompts"] for example in examples] 704
705 input_ids = [example["instance_prompt_ids"] for example in examples] 705 input_ids = [example["instance_prompt_ids"] for example in examples]
706 pixel_values = [example["instance_images"] for example in examples] 706 pixel_values = [example["instance_images"] for example in examples]
707 707
@@ -713,16 +713,18 @@ def main():
713 pixel_values = torch.stack(pixel_values) 713 pixel_values = torch.stack(pixel_values)
714 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) 714 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
715 715
716 prompts = prompt_processor.unify_input_ids(prompt_ids)
717 nprompts = prompt_processor.unify_input_ids(nprompt_ids)
716 inputs = prompt_processor.unify_input_ids(input_ids) 718 inputs = prompt_processor.unify_input_ids(input_ids)
717 719
718 batch = { 720 batch = {
719 "prompts": prompts, 721 "prompt_ids": prompts.input_ids,
720 "cprompts": cprompts, 722 "nprompt_ids": nprompts.input_ids,
721 "nprompts": nprompts,
722 "input_ids": inputs.input_ids, 723 "input_ids": inputs.input_ids,
723 "pixel_values": pixel_values, 724 "pixel_values": pixel_values,
724 "attention_mask": inputs.attention_mask, 725 "attention_mask": inputs.attention_mask,
725 } 726 }
727
726 return batch 728 return batch
727 729
728 datamodule = VlpnDataModule( 730 datamodule = VlpnDataModule(