From 0beac39e60fb4a79edb97a442884684d534722a4 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Jun 2023 09:11:32 +0200 Subject: Update --- data/prompt.py | 18 ++++++++++++++++++ train_dreambooth.py | 6 ------ train_lora.py | 6 ------ train_ti.py | 6 ------ training/functional.py | 1 - 5 files changed, 18 insertions(+), 19 deletions(-) create mode 100644 data/prompt.py diff --git a/data/prompt.py b/data/prompt.py new file mode 100644 index 0000000..0e66196 --- /dev/null +++ b/data/prompt.py @@ -0,0 +1,18 @@ +from torch.utils.data import Dataset + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]): + self.prompt_ids = prompt_ids + self.nprompt_ids = nprompt_ids + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, index): + example = {} + example["prompt_ids"] = self.prompt_ids[index] + example["nprompt_ids"] = self.nprompt_ids[index] + return example diff --git a/train_dreambooth.py b/train_dreambooth.py index 90ca467..dbe446d 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -203,12 +203,6 @@ def parse_args(): nargs="*", help="A collection to filter the dataset.", ) - parser.add_argument( - "--validation_prompts", - type=str, - nargs="*", - help="Prompts for additional validation images", - ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) diff --git a/train_lora.py b/train_lora.py index eeac81f..5ab353c 100644 --- a/train_lora.py +++ b/train_lora.py @@ -235,12 +235,6 @@ def parse_args(): nargs="*", help="A collection to filter the dataset.", ) - parser.add_argument( - "--validation_prompts", - type=str, - nargs="*", - help="Prompts for additional validation images", - ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) diff --git a/train_ti.py b/train_ti.py index a7d2924..2a599c1 100644 --- a/train_ti.py +++ b/train_ti.py @@ -159,12 +159,6 @@ def parse_args(): nargs="*", help="A collection to filter the dataset.", ) - parser.add_argument( - "--validation_prompts", - type=str, - nargs="*", - help="Prompts for additional validation images", - ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) diff --git a/training/functional.py b/training/functional.py index b60afe3..75f5d14 100644 --- a/training/functional.py +++ b/training/functional.py @@ -111,7 +111,6 @@ def save_samples( output_dir: Path, seed: int, step: int, - validation_prompts: list[str] = [], cycle: int = 1, batch_size: int = 1, num_batches: int = 1, -- cgit v1.2.3-54-g00ecf