-rw-r--r--  data/csv.py         | 40
-rw-r--r--  train_dreambooth.py | 47
-rw-r--r--  train_ti.py         | 47
3 files changed, 56 insertions, 78 deletions
diff --git a/data/csv.py b/data/csv.py
index 4da5d64..803271b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,28 +41,28 @@ class CSVDataItem(NamedTuple):
     prompt: list[str]
     cprompt: str
     nprompt: str
-    mode: list[str]
+    collection: list[str]
 
 
 class CSVDataModule():
     def __init__(
         self,
         batch_size: int,
         data_file: str,
         prompt_processor: PromptProcessor,
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
         filter: Optional[Callable[[CSVDataItem], bool]] = None,
         collate_fn=None,
         num_workers: int = 0
     ):
         super().__init__()
 
@@ -112,7 +112,7 @@ class CSVDataModule():
                     nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
                     expansions
                 )),
-                item["mode"].split(", ") if "mode" in item else []
+                item["collection"].split(", ") if "collection" in item else []
             )
             for item in data
         ]
@@ -133,7 +133,7 @@ class CSVDataModule():
                 item.prompt,
                 item.cprompt,
                 item.nprompt,
-                item.mode,
+                item.collection,
             )
             for item in items
             for i in range(image_multiplier)
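
For context, a minimal standalone sketch of what the renamed field means for a dataset row: a comma-separated `collection` column is split into a list on the `CSVDataItem`. The row values and the reduced field set below are illustrative only; the real class carries additional fields and the rows come from the configured data_file.

    from typing import NamedTuple

    class CSVDataItem(NamedTuple):
        # Only the fields touched by this diff are shown here.
        prompt: list[str]
        cprompt: str
        nprompt: str
        collection: list[str]  # renamed from `mode` in this commit

    # Hypothetical row, as it might appear in the data file.
    item = {"prompt": "a photo of <token>", "collection": "portraits, studio"}

    data_item = CSVDataItem(
        prompt=[item["prompt"]],
        cprompt="",
        nprompt="",
        # Mirrors the hunk above: split the comma-separated column, default to [].
        collection=item["collection"].split(", ") if "collection" in item else [],
    )
    print(data_item.collection)  # ['portraits', 'studio']
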
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 072150b..8fd78f1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -83,16 +83,10 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--exclude_keywords",
+        "--exclude_collections",
         type=str,
         nargs='*',
-        help="Skip dataset items containing a listed keyword.",
-    )
-    parser.add_argument(
-        "--exclude_modes",
-        type=str,
-        nargs='*',
-        help="Exclude all items with a listed mode.",
+        help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
         "--train_text_encoder",
@@ -142,10 +136,10 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
-        "--mode",
+        "--collection",
         type=str,
-        default=None,
-        help="A mode to filter the dataset.",
+        nargs='*',
+        help="A collection to filter the dataset.",
     )
     parser.add_argument(
         "--seed",
@@ -391,11 +385,11 @@ def parse_args():
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
 
-    if isinstance(args.exclude_keywords, str):
-        args.exclude_keywords = [args.exclude_keywords]
+    if isinstance(args.collection, str):
+        args.collection = [args.collection]
 
-    if isinstance(args.exclude_modes, str):
-        args.exclude_modes = [args.exclude_modes]
+    if isinstance(args.exclude_collections, str):
+        args.exclude_collections = [args.exclude_collections]
 
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
@@ -655,17 +649,12 @@ def main():
         weight_dtype = torch.bfloat16
 
     def keyword_filter(item: CSVDataItem):
-        cond2 = args.exclude_keywords is None or not any(
-            keyword in part
-            for keyword in args.exclude_keywords
-            for part in item.prompt
-        )
-        cond3 = args.mode is None or args.mode in item.mode
-        cond4 = args.exclude_modes is None or not any(
-            mode in item.mode
-            for mode in args.exclude_modes
+        cond3 = args.collection is None or args.collection in item.collection
+        cond4 = args.exclude_collections is None or not any(
+            collection in item.collection
+            for collection in args.exclude_collections
         )
-        return cond2 and cond3 and cond4
+        return cond3 and cond4
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
@@ -813,10 +802,10 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
-    if config["exclude_modes"] is not None:
-        config["exclude_modes"] = " ".join(config["exclude_modes"])
-    if config["exclude_keywords"] is not None:
-        config["exclude_keywords"] = " ".join(config["exclude_keywords"])
+    if config["collection"] is not None:
+        config["collection"] = " ".join(config["collection"])
+    if config["exclude_collections"] is not None:
+        config["exclude_collections"] = " ".join(config["exclude_collections"])
     accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
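
As a reading aid, the consolidated filter above can be sketched as a standalone function. Names mirror the diff, the argument and item values are made up, and the include test is written here as an any-match over the requested collections; note that the hunk itself tests `args.collection in item.collection` directly.

    from types import SimpleNamespace

    # Hypothetical stand-ins for the parsed CLI arguments and one dataset item.
    args = SimpleNamespace(collection=["portraits"], exclude_collections=["sketches"])
    item = SimpleNamespace(collection=["portraits", "studio"])

    def keyword_filter(item) -> bool:
        # Keep the item when no --collection filter is given, or when any
        # requested collection appears in the item's collection list.
        cond3 = args.collection is None or any(
            collection in item.collection
            for collection in args.collection
        )
        # Drop the item when it belongs to any excluded collection.
        cond4 = args.exclude_collections is None or not any(
            collection in item.collection
            for collection in args.exclude_collections
        )
        return cond3 and cond4

    print(keyword_filter(item))  # True
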
diff --git a/train_ti.py b/train_ti.py
index 6aa4007..088c1a6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -93,16 +93,10 @@ def parse_args():
         help="The directory where class images will be saved.",
     )
     parser.add_argument(
-        "--exclude_keywords",
+        "--exclude_collections",
         type=str,
         nargs='*',
-        help="Skip dataset items containing a listed keyword.",
-    )
-    parser.add_argument(
-        "--exclude_modes",
-        type=str,
-        nargs='*',
-        help="Exclude all items with a listed mode.",
+        help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
         "--repeats",
@@ -123,10 +117,10 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
-        "--mode",
+        "--collection",
         type=str,
-        default=None,
-        help="A mode to filter the dataset.",
+        nargs='*',
+        help="A collection to filter the dataset.",
     )
     parser.add_argument(
         "--seed",
@@ -369,11 +363,11 @@ def parse_args():
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("You must specify --placeholder_token")
 
-    if isinstance(args.exclude_keywords, str):
-        args.exclude_keywords = [args.exclude_keywords]
+    if isinstance(args.collection, str):
+        args.collection = [args.collection]
 
-    if isinstance(args.exclude_modes, str):
-        args.exclude_modes = [args.exclude_modes]
+    if isinstance(args.exclude_collections, str):
+        args.exclude_collections = [args.exclude_collections]
 
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
@@ -600,17 +594,12 @@ def main():
             for keyword in args.placeholder_token
             for part in item.prompt
         )
-        cond2 = args.exclude_keywords is None or not any(
-            keyword in part
-            for keyword in args.exclude_keywords
-            for part in item.prompt
-        )
-        cond3 = args.mode is None or args.mode in item.mode
-        cond4 = args.exclude_modes is None or not any(
-            mode in item.mode
-            for mode in args.exclude_modes
+        cond3 = args.collection is None or args.collection in item.collection
+        cond4 = args.exclude_collections is None or not any(
+            collection in item.collection
+            for collection in args.exclude_collections
         )
-        return cond1 and cond2 and cond3 and cond4
+        return cond1 and cond3 and cond4
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
@@ -827,10 +816,10 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
-    if config["exclude_modes"] is not None:
-        config["exclude_modes"] = " ".join(config["exclude_modes"])
-    if config["exclude_keywords"] is not None:
-        config["exclude_keywords"] = " ".join(config["exclude_keywords"])
+    if config["collection"] is not None:
+        config["collection"] = " ".join(config["collection"])
+    if config["exclude_collections"] is not None:
+        config["exclude_collections"] = " ".join(config["exclude_collections"])
     accelerator.init_trackers("textual_inversion", config=config)
 
     # Train!
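
To round things off, a brief sketch of how the new flags flow through both scripts: `--collection` and `--exclude_collections` are declared with nargs='*', normalized to lists in parse_args(), and joined back into space-separated strings before being logged as the tracker config. The command-line values below are illustrative; the real scripts take many more flags.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )

    # Hypothetical invocation, for illustration only.
    args = parser.parse_args(
        ["--collection", "portraits", "--exclude_collections", "sketches", "wip"]
    )

    # Same normalization as in parse_args(): a bare string becomes a one-item list.
    if isinstance(args.collection, str):
        args.collection = [args.collection]
    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    # Same flattening as before accelerator.init_trackers(): lists become strings.
    config = vars(args).copy()
    if config["collection"] is not None:
        config["collection"] = " ".join(config["collection"])
    if config["exclude_collections"] is not None:
        config["exclude_collections"] = " ".join(config["exclude_collections"])

    print(config)  # {'collection': 'portraits', 'exclude_collections': 'sketches wip'}
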