diff options
| -rw-r--r-- | data/csv.py | 40 | ||||
| -rw-r--r-- | train_dreambooth.py | 47 | ||||
| -rw-r--r-- | train_ti.py | 47 |
3 files changed, 56 insertions, 78 deletions
diff --git a/data/csv.py b/data/csv.py index 4da5d64..803271b 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -41,28 +41,28 @@ class CSVDataItem(NamedTuple): | |||
| 41 | prompt: list[str] | 41 | prompt: list[str] |
| 42 | cprompt: str | 42 | cprompt: str |
| 43 | nprompt: str | 43 | nprompt: str |
| 44 | mode: list[str] | 44 | collection: list[str] |
| 45 | 45 | ||
| 46 | 46 | ||
| 47 | class CSVDataModule(): | 47 | class CSVDataModule(): |
| 48 | def __init__( | 48 | def __init__( |
| 49 | self, | 49 | self, |
| 50 | batch_size: int, | 50 | batch_size: int, |
| 51 | data_file: str, | 51 | data_file: str, |
| 52 | prompt_processor: PromptProcessor, | 52 | prompt_processor: PromptProcessor, |
| 53 | class_subdir: str = "cls", | 53 | class_subdir: str = "cls", |
| 54 | num_class_images: int = 1, | 54 | num_class_images: int = 1, |
| 55 | size: int = 768, | 55 | size: int = 768, |
| 56 | repeats: int = 1, | 56 | repeats: int = 1, |
| 57 | dropout: float = 0, | 57 | dropout: float = 0, |
| 58 | interpolation: str = "bicubic", | 58 | interpolation: str = "bicubic", |
| 59 | center_crop: bool = False, | 59 | center_crop: bool = False, |
| 60 | template_key: str = "template", | 60 | template_key: str = "template", |
| 61 | valid_set_size: Optional[int] = None, | 61 | valid_set_size: Optional[int] = None, |
| 62 | generator: Optional[torch.Generator] = None, | 62 | generator: Optional[torch.Generator] = None, |
| 63 | filter: Optional[Callable[[CSVDataItem], bool]] = None, | 63 | filter: Optional[Callable[[CSVDataItem], bool]] = None, |
| 64 | collate_fn=None, | 64 | collate_fn=None, |
| 65 | num_workers: int = 0 | 65 | num_workers: int = 0 |
| 66 | ): | 66 | ): |
| 67 | super().__init__() | 67 | super().__init__() |
| 68 | 68 | ||
| @@ -112,7 +112,7 @@ class CSVDataModule(): | |||
| 112 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), | 112 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")), |
| 113 | expansions | 113 | expansions |
| 114 | )), | 114 | )), |
| 115 | item["mode"].split(", ") if "mode" in item else [] | 115 | item["collection"].split(", ") if "collection" in item else [] |
| 116 | ) | 116 | ) |
| 117 | for item in data | 117 | for item in data |
| 118 | ] | 118 | ] |
| @@ -133,7 +133,7 @@ class CSVDataModule(): | |||
| 133 | item.prompt, | 133 | item.prompt, |
| 134 | item.cprompt, | 134 | item.cprompt, |
| 135 | item.nprompt, | 135 | item.nprompt, |
| 136 | item.mode, | 136 | item.collection, |
| 137 | ) | 137 | ) |
| 138 | for item in items | 138 | for item in items |
| 139 | for i in range(image_multiplier) | 139 | for i in range(image_multiplier) |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 072150b..8fd78f1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -83,16 +83,10 @@ def parse_args(): | |||
| 83 | help="A token to use as initializer word." | 83 | help="A token to use as initializer word." |
| 84 | ) | 84 | ) |
| 85 | parser.add_argument( | 85 | parser.add_argument( |
| 86 | "--exclude_keywords", | 86 | "--exclude_collections", |
| 87 | type=str, | 87 | type=str, |
| 88 | nargs='*', | 88 | nargs='*', |
| 89 | help="Skip dataset items containing a listed keyword.", | 89 | help="Exclude all items with a listed collection.", |
| 90 | ) | ||
| 91 | parser.add_argument( | ||
| 92 | "--exclude_modes", | ||
| 93 | type=str, | ||
| 94 | nargs='*', | ||
| 95 | help="Exclude all items with a listed mode.", | ||
| 96 | ) | 90 | ) |
| 97 | parser.add_argument( | 91 | parser.add_argument( |
| 98 | "--train_text_encoder", | 92 | "--train_text_encoder", |
| @@ -142,10 +136,10 @@ def parse_args(): | |||
| 142 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 136 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 143 | ) | 137 | ) |
| 144 | parser.add_argument( | 138 | parser.add_argument( |
| 145 | "--mode", | 139 | "--collection", |
| 146 | type=str, | 140 | type=str, |
| 147 | default=None, | 141 | nargs='*', |
| 148 | help="A mode to filter the dataset.", | 142 | help="A collection to filter the dataset.", |
| 149 | ) | 143 | ) |
| 150 | parser.add_argument( | 144 | parser.add_argument( |
| 151 | "--seed", | 145 | "--seed", |
| @@ -391,11 +385,11 @@ def parse_args(): | |||
| 391 | if len(args.placeholder_token) != len(args.initializer_token): | 385 | if len(args.placeholder_token) != len(args.initializer_token): |
| 392 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") | 386 | raise ValueError("Number of items in --placeholder_token and --initializer_token must match") |
| 393 | 387 | ||
| 394 | if isinstance(args.exclude_keywords, str): | 388 | if isinstance(args.collection, str): |
| 395 | args.exclude_keywords = [args.exclude_keywords] | 389 | args.collection = [args.collection] |
| 396 | 390 | ||
| 397 | if isinstance(args.exclude_modes, str): | 391 | if isinstance(args.exclude_collections, str): |
| 398 | args.exclude_modes = [args.exclude_modes] | 392 | args.exclude_collections = [args.exclude_collections] |
| 399 | 393 | ||
| 400 | if args.output_dir is None: | 394 | if args.output_dir is None: |
| 401 | raise ValueError("You must specify --output_dir") | 395 | raise ValueError("You must specify --output_dir") |
| @@ -655,17 +649,12 @@ def main(): | |||
| 655 | weight_dtype = torch.bfloat16 | 649 | weight_dtype = torch.bfloat16 |
| 656 | 650 | ||
| 657 | def keyword_filter(item: CSVDataItem): | 651 | def keyword_filter(item: CSVDataItem): |
| 658 | cond2 = args.exclude_keywords is None or not any( | 652 | cond3 = args.collection is None or args.collection in item.collection |
| 659 | keyword in part | 653 | cond4 = args.exclude_collections is None or not any( |
| 660 | for keyword in args.exclude_keywords | 654 | collection in item.collection |
| 661 | for part in item.prompt | 655 | for collection in args.exclude_collections |
| 662 | ) | ||
| 663 | cond3 = args.mode is None or args.mode in item.mode | ||
| 664 | cond4 = args.exclude_modes is None or not any( | ||
| 665 | mode in item.mode | ||
| 666 | for mode in args.exclude_modes | ||
| 667 | ) | 656 | ) |
| 668 | return cond2 and cond3 and cond4 | 657 | return cond3 and cond4 |
| 669 | 658 | ||
| 670 | def collate_fn(examples): | 659 | def collate_fn(examples): |
| 671 | prompts = [example["prompts"] for example in examples] | 660 | prompts = [example["prompts"] for example in examples] |
| @@ -813,10 +802,10 @@ def main(): | |||
| 813 | config = vars(args).copy() | 802 | config = vars(args).copy() |
| 814 | config["initializer_token"] = " ".join(config["initializer_token"]) | 803 | config["initializer_token"] = " ".join(config["initializer_token"]) |
| 815 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 804 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
| 816 | if config["exclude_modes"] is not None: | 805 | if config["collection"] is not None: |
| 817 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | 806 | config["collection"] = " ".join(config["collection"]) |
| 818 | if config["exclude_keywords"] is not None: | 807 | if config["exclude_collections"] is not None: |
| 819 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | 808 | config["exclude_collections"] = " ".join(config["exclude_collections"]) |
| 820 | accelerator.init_trackers("dreambooth", config=config) | 809 | accelerator.init_trackers("dreambooth", config=config) |
| 821 | 810 | ||
| 822 | # Train! | 811 | # Train! |
diff --git a/train_ti.py b/train_ti.py index 6aa4007..088c1a6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -93,16 +93,10 @@ def parse_args(): | |||
| 93 | help="The directory where class images will be saved.", | 93 | help="The directory where class images will be saved.", |
| 94 | ) | 94 | ) |
| 95 | parser.add_argument( | 95 | parser.add_argument( |
| 96 | "--exclude_keywords", | 96 | "--exclude_collections", |
| 97 | type=str, | 97 | type=str, |
| 98 | nargs='*', | 98 | nargs='*', |
| 99 | help="Skip dataset items containing a listed keyword.", | 99 | help="Exclude all items with a listed collection.", |
| 100 | ) | ||
| 101 | parser.add_argument( | ||
| 102 | "--exclude_modes", | ||
| 103 | type=str, | ||
| 104 | nargs='*', | ||
| 105 | help="Exclude all items with a listed mode.", | ||
| 106 | ) | 100 | ) |
| 107 | parser.add_argument( | 101 | parser.add_argument( |
| 108 | "--repeats", | 102 | "--repeats", |
| @@ -123,10 +117,10 @@ def parse_args(): | |||
| 123 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 117 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 124 | ) | 118 | ) |
| 125 | parser.add_argument( | 119 | parser.add_argument( |
| 126 | "--mode", | 120 | "--collection", |
| 127 | type=str, | 121 | type=str, |
| 128 | default=None, | 122 | nargs='*', |
| 129 | help="A mode to filter the dataset.", | 123 | help="A collection to filter the dataset.", |
| 130 | ) | 124 | ) |
| 131 | parser.add_argument( | 125 | parser.add_argument( |
| 132 | "--seed", | 126 | "--seed", |
| @@ -369,11 +363,11 @@ def parse_args(): | |||
| 369 | if len(args.placeholder_token) != len(args.initializer_token): | 363 | if len(args.placeholder_token) != len(args.initializer_token): |
| 370 | raise ValueError("You must specify --placeholder_token") | 364 | raise ValueError("You must specify --placeholder_token") |
| 371 | 365 | ||
| 372 | if isinstance(args.exclude_keywords, str): | 366 | if isinstance(args.collection, str): |
| 373 | args.exclude_keywords = [args.exclude_keywords] | 367 | args.collection = [args.collection] |
| 374 | 368 | ||
| 375 | if isinstance(args.exclude_modes, str): | 369 | if isinstance(args.exclude_collections, str): |
| 376 | args.exclude_modes = [args.exclude_modes] | 370 | args.exclude_collections = [args.exclude_collections] |
| 377 | 371 | ||
| 378 | if args.output_dir is None: | 372 | if args.output_dir is None: |
| 379 | raise ValueError("You must specify --output_dir") | 373 | raise ValueError("You must specify --output_dir") |
| @@ -600,17 +594,12 @@ def main(): | |||
| 600 | for keyword in args.placeholder_token | 594 | for keyword in args.placeholder_token |
| 601 | for part in item.prompt | 595 | for part in item.prompt |
| 602 | ) | 596 | ) |
| 603 | cond2 = args.exclude_keywords is None or not any( | 597 | cond3 = args.collection is None or args.collection in item.collection |
| 604 | keyword in part | 598 | cond4 = args.exclude_collections is None or not any( |
| 605 | for keyword in args.exclude_keywords | 599 | collection in item.collection |
| 606 | for part in item.prompt | 600 | for collection in args.exclude_collections |
| 607 | ) | ||
| 608 | cond3 = args.mode is None or args.mode in item.mode | ||
| 609 | cond4 = args.exclude_modes is None or not any( | ||
| 610 | mode in item.mode | ||
| 611 | for mode in args.exclude_modes | ||
| 612 | ) | 601 | ) |
| 613 | return cond1 and cond2 and cond3 and cond4 | 602 | return cond1 and cond3 and cond4 |
| 614 | 603 | ||
| 615 | def collate_fn(examples): | 604 | def collate_fn(examples): |
| 616 | prompts = [example["prompts"] for example in examples] | 605 | prompts = [example["prompts"] for example in examples] |
| @@ -827,10 +816,10 @@ def main(): | |||
| 827 | config = vars(args).copy() | 816 | config = vars(args).copy() |
| 828 | config["initializer_token"] = " ".join(config["initializer_token"]) | 817 | config["initializer_token"] = " ".join(config["initializer_token"]) |
| 829 | config["placeholder_token"] = " ".join(config["placeholder_token"]) | 818 | config["placeholder_token"] = " ".join(config["placeholder_token"]) |
| 830 | if config["exclude_modes"] is not None: | 819 | if config["collection"] is not None: |
| 831 | config["exclude_modes"] = " ".join(config["exclude_modes"]) | 820 | config["collection"] = " ".join(config["collection"]) |
| 832 | if config["exclude_keywords"] is not None: | 821 | if config["exclude_collections"] is not None: |
| 833 | config["exclude_keywords"] = " ".join(config["exclude_keywords"]) | 822 | config["exclude_collections"] = " ".join(config["exclude_collections"]) |
| 834 | accelerator.init_trackers("textual_inversion", config=config) | 823 | accelerator.init_trackers("textual_inversion", config=config) |
| 835 | 824 | ||
| 836 | # Train! | 825 | # Train! |
