summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py40
-rw-r--r--train_dreambooth.py47
-rw-r--r--train_ti.py47
3 files changed, 56 insertions, 78 deletions
diff --git a/data/csv.py b/data/csv.py
index 4da5d64..803271b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,28 +41,28 @@ class CSVDataItem(NamedTuple):
41 prompt: list[str] 41 prompt: list[str]
42 cprompt: str 42 cprompt: str
43 nprompt: str 43 nprompt: str
44 mode: list[str] 44 collection: list[str]
45 45
46 46
47class CSVDataModule(): 47class CSVDataModule():
48 def __init__( 48 def __init__(
49 self, 49 self,
50 batch_size: int, 50 batch_size: int,
51 data_file: str, 51 data_file: str,
52 prompt_processor: PromptProcessor, 52 prompt_processor: PromptProcessor,
53 class_subdir: str = "cls", 53 class_subdir: str = "cls",
54 num_class_images: int = 1, 54 num_class_images: int = 1,
55 size: int = 768, 55 size: int = 768,
56 repeats: int = 1, 56 repeats: int = 1,
57 dropout: float = 0, 57 dropout: float = 0,
58 interpolation: str = "bicubic", 58 interpolation: str = "bicubic",
59 center_crop: bool = False, 59 center_crop: bool = False,
60 template_key: str = "template", 60 template_key: str = "template",
61 valid_set_size: Optional[int] = None, 61 valid_set_size: Optional[int] = None,
62 generator: Optional[torch.Generator] = None, 62 generator: Optional[torch.Generator] = None,
63 filter: Optional[Callable[[CSVDataItem], bool]] = None, 63 filter: Optional[Callable[[CSVDataItem], bool]] = None,
64 collate_fn=None, 64 collate_fn=None,
65 num_workers: int = 0 65 num_workers: int = 0
66 ): 66 ):
67 super().__init__() 67 super().__init__()
68 68
@@ -112,7 +112,7 @@ class CSVDataModule():
112 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), 112 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
113 expansions 113 expansions
114 )), 114 )),
115 item["mode"].split(", ") if "mode" in item else [] 115 item["collection"].split(", ") if "collection" in item else []
116 ) 116 )
117 for item in data 117 for item in data
118 ] 118 ]
@@ -133,7 +133,7 @@ class CSVDataModule():
133 item.prompt, 133 item.prompt,
134 item.cprompt, 134 item.cprompt,
135 item.nprompt, 135 item.nprompt,
136 item.mode, 136 item.collection,
137 ) 137 )
138 for item in items 138 for item in items
139 for i in range(image_multiplier) 139 for i in range(image_multiplier)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 072150b..8fd78f1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -83,16 +83,10 @@ def parse_args():
83 help="A token to use as initializer word." 83 help="A token to use as initializer word."
84 ) 84 )
85 parser.add_argument( 85 parser.add_argument(
86 "--exclude_keywords", 86 "--exclude_collections",
87 type=str, 87 type=str,
88 nargs='*', 88 nargs='*',
89 help="Skip dataset items containing a listed keyword.", 89 help="Exclude all items with a listed collection.",
90 )
91 parser.add_argument(
92 "--exclude_modes",
93 type=str,
94 nargs='*',
95 help="Exclude all items with a listed mode.",
96 ) 90 )
97 parser.add_argument( 91 parser.add_argument(
98 "--train_text_encoder", 92 "--train_text_encoder",
@@ -142,10 +136,10 @@ def parse_args():
142 help="The embeddings directory where Textual Inversion embeddings are stored.", 136 help="The embeddings directory where Textual Inversion embeddings are stored.",
143 ) 137 )
144 parser.add_argument( 138 parser.add_argument(
145 "--mode", 139 "--collection",
146 type=str, 140 type=str,
147 default=None, 141 nargs='*',
148 help="A mode to filter the dataset.", 142 help="A collection to filter the dataset.",
149 ) 143 )
150 parser.add_argument( 144 parser.add_argument(
151 "--seed", 145 "--seed",
@@ -391,11 +385,11 @@ def parse_args():
391 if len(args.placeholder_token) != len(args.initializer_token): 385 if len(args.placeholder_token) != len(args.initializer_token):
392 raise ValueError("Number of items in --placeholder_token and --initializer_token must match") 386 raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
393 387
394 if isinstance(args.exclude_keywords, str): 388 if isinstance(args.collection, str):
395 args.exclude_keywords = [args.exclude_keywords] 389 args.collection = [args.collection]
396 390
397 if isinstance(args.exclude_modes, str): 391 if isinstance(args.exclude_collections, str):
398 args.exclude_modes = [args.exclude_modes] 392 args.exclude_collections = [args.exclude_collections]
399 393
400 if args.output_dir is None: 394 if args.output_dir is None:
401 raise ValueError("You must specify --output_dir") 395 raise ValueError("You must specify --output_dir")
@@ -655,17 +649,12 @@ def main():
655 weight_dtype = torch.bfloat16 649 weight_dtype = torch.bfloat16
656 650
657 def keyword_filter(item: CSVDataItem): 651 def keyword_filter(item: CSVDataItem):
658 cond2 = args.exclude_keywords is None or not any( 652 cond3 = args.collection is None or args.collection in item.collection
659 keyword in part 653 cond4 = args.exclude_collections is None or not any(
660 for keyword in args.exclude_keywords 654 collection in item.collection
661 for part in item.prompt 655 for collection in args.exclude_collections
662 )
663 cond3 = args.mode is None or args.mode in item.mode
664 cond4 = args.exclude_modes is None or not any(
665 mode in item.mode
666 for mode in args.exclude_modes
667 ) 656 )
668 return cond2 and cond3 and cond4 657 return cond3 and cond4
669 658
670 def collate_fn(examples): 659 def collate_fn(examples):
671 prompts = [example["prompts"] for example in examples] 660 prompts = [example["prompts"] for example in examples]
@@ -813,10 +802,10 @@ def main():
813 config = vars(args).copy() 802 config = vars(args).copy()
814 config["initializer_token"] = " ".join(config["initializer_token"]) 803 config["initializer_token"] = " ".join(config["initializer_token"])
815 config["placeholder_token"] = " ".join(config["placeholder_token"]) 804 config["placeholder_token"] = " ".join(config["placeholder_token"])
816 if config["exclude_modes"] is not None: 805 if config["collection"] is not None:
817 config["exclude_modes"] = " ".join(config["exclude_modes"]) 806 config["collection"] = " ".join(config["collection"])
818 if config["exclude_keywords"] is not None: 807 if config["exclude_collections"] is not None:
819 config["exclude_keywords"] = " ".join(config["exclude_keywords"]) 808 config["exclude_collections"] = " ".join(config["exclude_collections"])
820 accelerator.init_trackers("dreambooth", config=config) 809 accelerator.init_trackers("dreambooth", config=config)
821 810
822 # Train! 811 # Train!
diff --git a/train_ti.py b/train_ti.py
index 6aa4007..088c1a6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -93,16 +93,10 @@ def parse_args():
93 help="The directory where class images will be saved.", 93 help="The directory where class images will be saved.",
94 ) 94 )
95 parser.add_argument( 95 parser.add_argument(
96 "--exclude_keywords", 96 "--exclude_collections",
97 type=str, 97 type=str,
98 nargs='*', 98 nargs='*',
99 help="Skip dataset items containing a listed keyword.", 99 help="Exclude all items with a listed collection.",
100 )
101 parser.add_argument(
102 "--exclude_modes",
103 type=str,
104 nargs='*',
105 help="Exclude all items with a listed mode.",
106 ) 100 )
107 parser.add_argument( 101 parser.add_argument(
108 "--repeats", 102 "--repeats",
@@ -123,10 +117,10 @@ def parse_args():
123 help="The embeddings directory where Textual Inversion embeddings are stored.", 117 help="The embeddings directory where Textual Inversion embeddings are stored.",
124 ) 118 )
125 parser.add_argument( 119 parser.add_argument(
126 "--mode", 120 "--collection",
127 type=str, 121 type=str,
128 default=None, 122 nargs='*',
129 help="A mode to filter the dataset.", 123 help="A collection to filter the dataset.",
130 ) 124 )
131 parser.add_argument( 125 parser.add_argument(
132 "--seed", 126 "--seed",
@@ -369,11 +363,11 @@ def parse_args():
369 if len(args.placeholder_token) != len(args.initializer_token): 363 if len(args.placeholder_token) != len(args.initializer_token):
370 raise ValueError("You must specify --placeholder_token") 364 raise ValueError("You must specify --placeholder_token")
371 365
372 if isinstance(args.exclude_keywords, str): 366 if isinstance(args.collection, str):
373 args.exclude_keywords = [args.exclude_keywords] 367 args.collection = [args.collection]
374 368
375 if isinstance(args.exclude_modes, str): 369 if isinstance(args.exclude_collections, str):
376 args.exclude_modes = [args.exclude_modes] 370 args.exclude_collections = [args.exclude_collections]
377 371
378 if args.output_dir is None: 372 if args.output_dir is None:
379 raise ValueError("You must specify --output_dir") 373 raise ValueError("You must specify --output_dir")
@@ -600,17 +594,12 @@ def main():
600 for keyword in args.placeholder_token 594 for keyword in args.placeholder_token
601 for part in item.prompt 595 for part in item.prompt
602 ) 596 )
603 cond2 = args.exclude_keywords is None or not any( 597 cond3 = args.collection is None or args.collection in item.collection
604 keyword in part 598 cond4 = args.exclude_collections is None or not any(
605 for keyword in args.exclude_keywords 599 collection in item.collection
606 for part in item.prompt 600 for collection in args.exclude_collections
607 )
608 cond3 = args.mode is None or args.mode in item.mode
609 cond4 = args.exclude_modes is None or not any(
610 mode in item.mode
611 for mode in args.exclude_modes
612 ) 601 )
613 return cond1 and cond2 and cond3 and cond4 602 return cond1 and cond3 and cond4
614 603
615 def collate_fn(examples): 604 def collate_fn(examples):
616 prompts = [example["prompts"] for example in examples] 605 prompts = [example["prompts"] for example in examples]
@@ -827,10 +816,10 @@ def main():
827 config = vars(args).copy() 816 config = vars(args).copy()
828 config["initializer_token"] = " ".join(config["initializer_token"]) 817 config["initializer_token"] = " ".join(config["initializer_token"])
829 config["placeholder_token"] = " ".join(config["placeholder_token"]) 818 config["placeholder_token"] = " ".join(config["placeholder_token"])
830 if config["exclude_modes"] is not None: 819 if config["collection"] is not None:
831 config["exclude_modes"] = " ".join(config["exclude_modes"]) 820 config["collection"] = " ".join(config["collection"])
832 if config["exclude_keywords"] is not None: 821 if config["exclude_collections"] is not None:
833 config["exclude_keywords"] = " ".join(config["exclude_keywords"]) 822 config["exclude_collections"] = " ".join(config["exclude_collections"])
834 accelerator.init_trackers("textual_inversion", config=config) 823 accelerator.init_trackers("textual_inversion", config=config)
835 824
836 # Train! 825 # Train!