diff options
| -rw-r--r-- | data/prompt.py | 18 | ||||
| -rw-r--r-- | train_dreambooth.py | 6 | ||||
| -rw-r--r-- | train_lora.py | 6 | ||||
| -rw-r--r-- | train_ti.py | 6 | ||||
| -rw-r--r-- | training/functional.py | 1 |
5 files changed, 18 insertions, 19 deletions
diff --git a/data/prompt.py b/data/prompt.py new file mode 100644 index 0000000..0e66196 --- /dev/null +++ b/data/prompt.py | |||
| @@ -0,0 +1,18 @@ | |||
| 1 | from torch.utils.data import Dataset | ||
| 2 | |||
| 3 | |||
| 4 | class PromptDataset(Dataset): | ||
| 5 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." | ||
| 6 | |||
| 7 | def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]): | ||
| 8 | self.prompt_ids = prompt_ids | ||
| 9 | self.nprompt_ids = nprompt_ids | ||
| 10 | |||
| 11 | def __len__(self): | ||
| 12 | return len(self.prompts) | ||
| 13 | |||
| 14 | def __getitem__(self, index): | ||
| 15 | example = {} | ||
| 16 | example["prompt_ids"] = self.prompt_ids[index] | ||
| 17 | example["nprompt_ids"] = self.nprompt_ids[index] | ||
| 18 | return example | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 90ca467..dbe446d 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -204,12 +204,6 @@ def parse_args(): | |||
| 204 | help="A collection to filter the dataset.", | 204 | help="A collection to filter the dataset.", |
| 205 | ) | 205 | ) |
| 206 | parser.add_argument( | 206 | parser.add_argument( |
| 207 | "--validation_prompts", | ||
| 208 | type=str, | ||
| 209 | nargs="*", | ||
| 210 | help="Prompts for additional validation images", | ||
| 211 | ) | ||
| 212 | parser.add_argument( | ||
| 213 | "--seed", type=int, default=None, help="A seed for reproducible training." | 207 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 214 | ) | 208 | ) |
| 215 | parser.add_argument( | 209 | parser.add_argument( |
diff --git a/train_lora.py b/train_lora.py index eeac81f..5ab353c 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -236,12 +236,6 @@ def parse_args(): | |||
| 236 | help="A collection to filter the dataset.", | 236 | help="A collection to filter the dataset.", |
| 237 | ) | 237 | ) |
| 238 | parser.add_argument( | 238 | parser.add_argument( |
| 239 | "--validation_prompts", | ||
| 240 | type=str, | ||
| 241 | nargs="*", | ||
| 242 | help="Prompts for additional validation images", | ||
| 243 | ) | ||
| 244 | parser.add_argument( | ||
| 245 | "--seed", type=int, default=None, help="A seed for reproducible training." | 239 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 246 | ) | 240 | ) |
| 247 | parser.add_argument( | 241 | parser.add_argument( |
diff --git a/train_ti.py b/train_ti.py index a7d2924..2a599c1 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -160,12 +160,6 @@ def parse_args(): | |||
| 160 | help="A collection to filter the dataset.", | 160 | help="A collection to filter the dataset.", |
| 161 | ) | 161 | ) |
| 162 | parser.add_argument( | 162 | parser.add_argument( |
| 163 | "--validation_prompts", | ||
| 164 | type=str, | ||
| 165 | nargs="*", | ||
| 166 | help="Prompts for additional validation images", | ||
| 167 | ) | ||
| 168 | parser.add_argument( | ||
| 169 | "--seed", type=int, default=None, help="A seed for reproducible training." | 163 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 170 | ) | 164 | ) |
| 171 | parser.add_argument( | 165 | parser.add_argument( |
diff --git a/training/functional.py b/training/functional.py index b60afe3..75f5d14 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -111,7 +111,6 @@ def save_samples( | |||
| 111 | output_dir: Path, | 111 | output_dir: Path, |
| 112 | seed: int, | 112 | seed: int, |
| 113 | step: int, | 113 | step: int, |
| 114 | validation_prompts: list[str] = [], | ||
| 115 | cycle: int = 1, | 114 | cycle: int = 1, |
| 116 | batch_size: int = 1, | 115 | batch_size: int = 1, |
| 117 | num_batches: int = 1, | 116 | num_batches: int = 1, |
