diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 14 |
1 file changed, 8 insertions, 6 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 42a7d0f..79eede6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -699,9 +699,9 @@ def main(): | |||
699 | return cond3 and cond4 | 699 | return cond3 and cond4 |
700 | 700 | ||
701 | def collate_fn(examples): | 701 | def collate_fn(examples): |
702 | prompts = [example["prompts"] for example in examples] | 702 | prompt_ids = [example["prompt_ids"] for example in examples] |
703 | cprompts = [example["cprompts"] for example in examples] | 703 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
704 | nprompts = [example["nprompts"] for example in examples] | 704 | |
705 | input_ids = [example["instance_prompt_ids"] for example in examples] | 705 | input_ids = [example["instance_prompt_ids"] for example in examples] |
706 | pixel_values = [example["instance_images"] for example in examples] | 706 | pixel_values = [example["instance_images"] for example in examples] |
707 | 707 | ||
@@ -713,16 +713,18 @@ def main(): | |||
713 | pixel_values = torch.stack(pixel_values) | 713 | pixel_values = torch.stack(pixel_values) |
714 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | 714 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) |
715 | 715 | ||
716 | prompts = prompt_processor.unify_input_ids(prompt_ids) | ||
717 | nprompts = prompt_processor.unify_input_ids(nprompt_ids) | ||
716 | inputs = prompt_processor.unify_input_ids(input_ids) | 718 | inputs = prompt_processor.unify_input_ids(input_ids) |
717 | 719 | ||
718 | batch = { | 720 | batch = { |
719 | "prompts": prompts, | 721 | "prompt_ids": prompts.input_ids, |
720 | "cprompts": cprompts, | 722 | "nprompt_ids": nprompts.input_ids, |
721 | "nprompts": nprompts, | ||
722 | "input_ids": inputs.input_ids, | 723 | "input_ids": inputs.input_ids, |
723 | "pixel_values": pixel_values, | 724 | "pixel_values": pixel_values, |
724 | "attention_mask": inputs.attention_mask, | 725 | "attention_mask": inputs.attention_mask, |
725 | } | 726 | } |
727 | |||
726 | return batch | 728 | return batch |
727 | 729 | ||
728 | datamodule = VlpnDataModule( | 730 | datamodule = VlpnDataModule( |