From 799a2ed9c9735d11887600ee57ebb7471cdf6f43 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 30 Dec 2022 14:04:59 +0100
Subject: Misc improvements

---
 data/csv.py         | 40 ++++++++++++++++++++--------------------
 train_dreambooth.py | 47 ++++++++++++++++++-----------------------------
 train_ti.py         | 47 ++++++++++++++++++-----------------------------
 3 files changed, 56 insertions(+), 78 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 4da5d64..803271b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,28 +41,28 @@ class CSVDataItem(NamedTuple):
     prompt: list[str]
     cprompt: str
     nprompt: str
-    mode: list[str]
+    collection: list[str]


 class CSVDataModule():
     def __init__(
-            self,
-            batch_size: int,
-            data_file: str,
-            prompt_processor: PromptProcessor,
-            class_subdir: str = "cls",
-            num_class_images: int = 1,
-            size: int = 768,
-            repeats: int = 1,
-            dropout: float = 0,
-            interpolation: str = "bicubic",
-            center_crop: bool = False,
-            template_key: str = "template",
-            valid_set_size: Optional[int] = None,
-            generator: Optional[torch.Generator] = None,
-            filter: Optional[Callable[[CSVDataItem], bool]] = None,
-            collate_fn=None,
-            num_workers: int = 0
+        self,
+        batch_size: int,
+        data_file: str,
+        prompt_processor: PromptProcessor,
+        class_subdir: str = "cls",
+        num_class_images: int = 1,
+        size: int = 768,
+        repeats: int = 1,
+        dropout: float = 0,
+        interpolation: str = "bicubic",
+        center_crop: bool = False,
+        template_key: str = "template",
+        valid_set_size: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
+        filter: Optional[Callable[[CSVDataItem], bool]] = None,
+        collate_fn=None,
+        num_workers: int = 0
     ):
         super().__init__()

@@ -112,7 +112,7 @@ class CSVDataModule():
                     nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
                     expansions
                 )),
-                item["mode"].split(", ") if "mode" in item else []
+                item["collection"].split(", ") if "collection" in item else []
             )
             for item in data
         ]
@@ -133,7 +133,7 @@ class CSVDataModule():
                 item.prompt,
                 item.cprompt,
                 item.nprompt,
-                item.mode,
+                item.collection,
             )
             for item in items
             for i in range(image_multiplier)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 072150b..8fd78f1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -83,16 +83,10 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--exclude_keywords",
+        "--exclude_collections",
         type=str,
         nargs='*',
-        help="Skip dataset items containing a listed keyword.",
-    )
-    parser.add_argument(
-        "--exclude_modes",
-        type=str,
-        nargs='*',
-        help="Exclude all items with a listed mode.",
+        help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
         "--train_text_encoder",
@@ -142,10 +136,10 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
-        "--mode",
+        "--collection",
         type=str,
-        default=None,
-        help="A mode to filter the dataset.",
+        nargs='*',
+        help="A collection to filter the dataset.",
     )
     parser.add_argument(
         "--seed",
@@ -391,11 +385,11 @@ def parse_args():
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("Number of items in --placeholder_token and --initializer_token must match")

-    if isinstance(args.exclude_keywords, str):
-        args.exclude_keywords = [args.exclude_keywords]
+    if isinstance(args.collection, str):
+        args.collection = [args.collection]

-    if isinstance(args.exclude_modes, str):
-        args.exclude_modes = [args.exclude_modes]
+    if isinstance(args.exclude_collections, str):
+        args.exclude_collections = [args.exclude_collections]

     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
@@ -655,17 +649,12 @@ def main():
         weight_dtype = torch.bfloat16

     def keyword_filter(item: CSVDataItem):
-        cond2 = args.exclude_keywords is None or not any(
-            keyword in part
-            for keyword in args.exclude_keywords
-            for part in item.prompt
-        )
-        cond3 = args.mode is None or args.mode in item.mode
-        cond4 = args.exclude_modes is None or not any(
-            mode in item.mode
-            for mode in args.exclude_modes
+        cond3 = args.collection is None or args.collection in item.collection
+        cond4 = args.exclude_collections is None or not any(
+            collection in item.collection
+            for collection in args.exclude_collections
         )
-        return cond2 and cond3 and cond4
+        return cond3 and cond4

     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
@@ -813,10 +802,10 @@ def main():
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
         config["placeholder_token"] = " ".join(config["placeholder_token"])
-        if config["exclude_modes"] is not None:
-            config["exclude_modes"] = " ".join(config["exclude_modes"])
-        if config["exclude_keywords"] is not None:
-            config["exclude_keywords"] = " ".join(config["exclude_keywords"])
+        if config["collection"] is not None:
+            config["collection"] = " ".join(config["collection"])
+        if config["exclude_collections"] is not None:
+            config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)

     # Train!
diff --git a/train_ti.py b/train_ti.py
index 6aa4007..088c1a6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -93,16 +93,10 @@ def parse_args():
         help="The directory where class images will be saved.",
     )
     parser.add_argument(
-        "--exclude_keywords",
+        "--exclude_collections",
         type=str,
         nargs='*',
-        help="Skip dataset items containing a listed keyword.",
-    )
-    parser.add_argument(
-        "--exclude_modes",
-        type=str,
-        nargs='*',
-        help="Exclude all items with a listed mode.",
+        help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
         "--repeats",
@@ -123,10 +117,10 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
     parser.add_argument(
-        "--mode",
+        "--collection",
         type=str,
-        default=None,
-        help="A mode to filter the dataset.",
+        nargs='*',
+        help="A collection to filter the dataset.",
     )
     parser.add_argument(
         "--seed",
@@ -369,11 +363,11 @@ def parse_args():
     if len(args.placeholder_token) != len(args.initializer_token):
         raise ValueError("You must specify --placeholder_token")

-    if isinstance(args.exclude_keywords, str):
-        args.exclude_keywords = [args.exclude_keywords]
+    if isinstance(args.collection, str):
+        args.collection = [args.collection]

-    if isinstance(args.exclude_modes, str):
-        args.exclude_modes = [args.exclude_modes]
+    if isinstance(args.exclude_collections, str):
+        args.exclude_collections = [args.exclude_collections]

     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
@@ -600,17 +594,12 @@ def main():
             for keyword in args.placeholder_token
             for part in item.prompt
         )
-        cond2 = args.exclude_keywords is None or not any(
-            keyword in part
-            for keyword in args.exclude_keywords
-            for part in item.prompt
-        )
-        cond3 = args.mode is None or args.mode in item.mode
-        cond4 = args.exclude_modes is None or not any(
-            mode in item.mode
-            for mode in args.exclude_modes
+        cond3 = args.collection is None or args.collection in item.collection
+        cond4 = args.exclude_collections is None or not any(
+            collection in item.collection
+            for collection in args.exclude_collections
         )
-        return cond1 and cond2 and cond3 and cond4
+        return cond1 and cond3 and cond4

     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
@@ -827,10 +816,10 @@ def main():
         config = vars(args).copy()
         config["initializer_token"] = " ".join(config["initializer_token"])
         config["placeholder_token"] = " ".join(config["placeholder_token"])
-        if config["exclude_modes"] is not None:
-            config["exclude_modes"] = " ".join(config["exclude_modes"])
-        if config["exclude_keywords"] is not None:
-            config["exclude_keywords"] = " ".join(config["exclude_keywords"])
+        if config["collection"] is not None:
+            config["collection"] = " ".join(config["collection"])
+        if config["exclude_collections"] is not None:
+            config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("textual_inversion", config=config)

     # Train!
--
cgit v1.2.3-54-g00ecf
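
A note on the new filter flags: --collection is declared with nargs='*', so args.collection arrives as a list of strings, and keyword_filter's check "args.collection in item.collection" asks whether that whole list is an element of the item's collection list, rather than whether any requested collection is present. The short, self-contained sketch below is not code from this repository: it mirrors the CSVDataItem fields and the flag semantics shown in the diff, takes the two flag values as explicit parameters instead of reading a surrounding args object, and assumes the intended --collection behaviour is "keep the item if any requested collection is listed"; the sample data is made up for illustration.

# Standalone sketch of the collection-based filtering introduced by this patch.
# The any()-based matching for --collection and the sample data are assumptions
# about the intended behaviour, not code taken from train_dreambooth.py/train_ti.py.
from typing import NamedTuple, Optional


class CSVDataItem(NamedTuple):
    # Trimmed to the fields relevant here; the real class has more members.
    prompt: list[str]
    cprompt: str
    nprompt: str
    collection: list[str]


def keyword_filter(
    item: CSVDataItem,
    collection: Optional[list[str]],
    exclude_collections: Optional[list[str]],
) -> bool:
    # Keep the item when no --collection filter is given, or when any requested
    # collection appears in the item's collection list (a list-aware variant of
    # the patch's "args.collection in item.collection").
    cond3 = collection is None or any(c in item.collection for c in collection)
    # Drop the item when any of its collections is listed in --exclude_collections;
    # this mirrors the cond4 expression from the patch.
    cond4 = exclude_collections is None or not any(
        c in item.collection for c in exclude_collections
    )
    return cond3 and cond4


if __name__ == "__main__":
    item = CSVDataItem(["a photo of X"], "a photo", "lowres", ["portraits"])
    print(keyword_filter(item, ["portraits"], None))   # True: requested collection matches
    print(keyword_filter(item, None, ["portraits"]))   # False: collection is excluded
    print(keyword_filter(item, None, None))            # True: no filters given

If the any()-based matching is indeed the intent, the same expression could replace cond3 in both training scripts; as written, --collection only matches when an item's collection list contains the full list object, which the string entries produced by item["collection"].split(", ") never satisfy.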