From 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 8 Jan 2023 13:38:43 +0100 Subject: Fixed aspect ratio bucketing; allow passing token IDs to pipeline --- train_dreambooth.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 42a7d0f..79eede6 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -699,9 +699,9 @@ def main(): return cond3 and cond4 def collate_fn(examples): - prompts = [example["prompts"] for example in examples] - cprompts = [example["cprompts"] for example in examples] - nprompts = [example["nprompts"] for example in examples] + prompt_ids = [example["prompt_ids"] for example in examples] + nprompt_ids = [example["nprompt_ids"] for example in examples] + input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -713,16 +713,18 @@ def main(): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) + prompts = prompt_processor.unify_input_ids(prompt_ids) + nprompts = prompt_processor.unify_input_ids(nprompt_ids) inputs = prompt_processor.unify_input_ids(input_ids) batch = { - "prompts": prompts, - "cprompts": cprompts, - "nprompts": nprompts, + "prompt_ids": prompts.input_ids, + "nprompt_ids": nprompts.input_ids, "input_ids": inputs.input_ids, "pixel_values": pixel_values, "attention_mask": inputs.attention_mask, } + return batch datamodule = VlpnDataModule( -- cgit v1.2.3-54-g00ecf