summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-25 09:11:32 +0200
committerVolpeon <git@volpeon.ink>2023-06-25 09:11:32 +0200
commit0beac39e60fb4a79edb97a442884684d534722a4 (patch)
tree5a5f545155d64906378772d7a5fcbcc6fab2b430
parentUpdate (diff)
downloadtextual-inversion-diff-0beac39e60fb4a79edb97a442884684d534722a4.tar.gz
textual-inversion-diff-0beac39e60fb4a79edb97a442884684d534722a4.tar.bz2
textual-inversion-diff-0beac39e60fb4a79edb97a442884684d534722a4.zip
-rw-r--r--data/prompt.py18
-rw-r--r--train_dreambooth.py6
-rw-r--r--train_lora.py6
-rw-r--r--train_ti.py6
-rw-r--r--training/functional.py1
5 files changed, 18 insertions, 19 deletions
diff --git a/data/prompt.py b/data/prompt.py
new file mode 100644
index 0000000..0e66196
--- /dev/null
+++ b/data/prompt.py
@@ -0,0 +1,18 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
6
7 def __init__(self, prompt_ids: list[int], nprompt_ids: list[int]):
8 self.prompt_ids = prompt_ids
9 self.nprompt_ids = nprompt_ids
10
11 def __len__(self):
12 return len(self.prompts)
13
14 def __getitem__(self, index):
15 example = {}
16 example["prompt_ids"] = self.prompt_ids[index]
17 example["nprompt_ids"] = self.nprompt_ids[index]
18 return example
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 90ca467..dbe446d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -204,12 +204,6 @@ def parse_args():
204 help="A collection to filter the dataset.", 204 help="A collection to filter the dataset.",
205 ) 205 )
206 parser.add_argument( 206 parser.add_argument(
207 "--validation_prompts",
208 type=str,
209 nargs="*",
210 help="Prompts for additional validation images",
211 )
212 parser.add_argument(
213 "--seed", type=int, default=None, help="A seed for reproducible training." 207 "--seed", type=int, default=None, help="A seed for reproducible training."
214 ) 208 )
215 parser.add_argument( 209 parser.add_argument(
diff --git a/train_lora.py b/train_lora.py
index eeac81f..5ab353c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -236,12 +236,6 @@ def parse_args():
236 help="A collection to filter the dataset.", 236 help="A collection to filter the dataset.",
237 ) 237 )
238 parser.add_argument( 238 parser.add_argument(
239 "--validation_prompts",
240 type=str,
241 nargs="*",
242 help="Prompts for additional validation images",
243 )
244 parser.add_argument(
245 "--seed", type=int, default=None, help="A seed for reproducible training." 239 "--seed", type=int, default=None, help="A seed for reproducible training."
246 ) 240 )
247 parser.add_argument( 241 parser.add_argument(
diff --git a/train_ti.py b/train_ti.py
index a7d2924..2a599c1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -160,12 +160,6 @@ def parse_args():
160 help="A collection to filter the dataset.", 160 help="A collection to filter the dataset.",
161 ) 161 )
162 parser.add_argument( 162 parser.add_argument(
163 "--validation_prompts",
164 type=str,
165 nargs="*",
166 help="Prompts for additional validation images",
167 )
168 parser.add_argument(
169 "--seed", type=int, default=None, help="A seed for reproducible training." 163 "--seed", type=int, default=None, help="A seed for reproducible training."
170 ) 164 )
171 parser.add_argument( 165 parser.add_argument(
diff --git a/training/functional.py b/training/functional.py
index b60afe3..75f5d14 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -111,7 +111,6 @@ def save_samples(
111 output_dir: Path, 111 output_dir: Path,
112 seed: int, 112 seed: int,
113 step: int, 113 step: int,
114 validation_prompts: list[str] = [],
115 cycle: int = 1, 114 cycle: int = 1,
116 batch_size: int = 1, 115 batch_size: int = 1,
117 num_batches: int = 1, 116 num_batches: int = 1,