diff options
| -rw-r--r-- | data/csv.py | 38 | ||||
| -rw-r--r-- | dreambooth.py | 6 | ||||
| -rw-r--r-- | textual_inversion.py | 13 |
3 files changed, 47 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py index 9125212..9c3c3f8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -54,8 +54,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 54 | dropout: float = 0, | 54 | dropout: float = 0, |
| 55 | interpolation: str = "bicubic", | 55 | interpolation: str = "bicubic", |
| 56 | center_crop: bool = False, | 56 | center_crop: bool = False, |
| 57 | template_key: str = "template", | ||
| 57 | valid_set_size: Optional[int] = None, | 58 | valid_set_size: Optional[int] = None, |
| 58 | generator: Optional[torch.Generator] = None, | 59 | generator: Optional[torch.Generator] = None, |
| 60 | keyword_filter: list[str] = [], | ||
| 59 | collate_fn=None, | 61 | collate_fn=None, |
| 60 | num_workers: int = 0 | 62 | num_workers: int = 0 |
| 61 | ): | 63 | ): |
| @@ -78,38 +80,60 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 78 | self.repeats = repeats | 80 | self.repeats = repeats |
| 79 | self.dropout = dropout | 81 | self.dropout = dropout |
| 80 | self.center_crop = center_crop | 82 | self.center_crop = center_crop |
| 83 | self.template_key = template_key | ||
| 81 | self.interpolation = interpolation | 84 | self.interpolation = interpolation |
| 82 | self.valid_set_size = valid_set_size | 85 | self.valid_set_size = valid_set_size |
| 83 | self.generator = generator | 86 | self.generator = generator |
| 87 | self.keyword_filter = keyword_filter | ||
| 84 | self.collate_fn = collate_fn | 88 | self.collate_fn = collate_fn |
| 85 | self.num_workers = num_workers | 89 | self.num_workers = num_workers |
| 86 | self.batch_size = batch_size | 90 | self.batch_size = batch_size |
| 87 | 91 | ||
| 88 | def prepare_subdata(self, template, data, num_class_images=1): | 92 | def prepare_items(self, template, data) -> list[CSVDataItem]: |
| 89 | image = template["image"] if "image" in template else "{}" | 93 | image = template["image"] if "image" in template else "{}" |
| 90 | prompt = template["prompt"] if "prompt" in template else "{content}" | 94 | prompt = template["prompt"] if "prompt" in template else "{content}" |
| 91 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 95 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
| 92 | 96 | ||
| 93 | image_multiplier = max(math.ceil(num_class_images / len(data)), 1) | ||
| 94 | |||
| 95 | return [ | 97 | return [ |
| 96 | CSVDataItem( | 98 | CSVDataItem( |
| 97 | self.data_root.joinpath(image.format(item["image"])), | 99 | self.data_root.joinpath(image.format(item["image"])), |
| 98 | self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"), | 100 | None, |
| 99 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 101 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
| 100 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) | 102 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) |
| 101 | ) | 103 | ) |
| 102 | for item in data | 104 | for item in data |
| 105 | ] | ||
| 106 | |||
| 107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: | ||
| 108 | if len(self.keyword_filter) == 0: | ||
| 109 | return items | ||
| 110 | |||
| 111 | return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] | ||
| 112 | |||
| 113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | ||
| 114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) | ||
| 115 | |||
| 116 | return [ | ||
| 117 | CSVDataItem( | ||
| 118 | item.instance_image_path, | ||
| 119 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | ||
| 120 | item.prompt, | ||
| 121 | item.nprompt | ||
| 122 | ) | ||
| 123 | for item in items | ||
| 103 | for i in range(image_multiplier) | 124 | for i in range(image_multiplier) |
| 104 | ] | 125 | ] |
| 105 | 126 | ||
| 106 | def prepare_data(self): | 127 | def prepare_data(self): |
| 107 | with open(self.data_file, 'rt') as f: | 128 | with open(self.data_file, 'rt') as f: |
| 108 | metadata = json.load(f) | 129 | metadata = json.load(f) |
| 109 | template = metadata["template"] if "template" in metadata else {} | 130 | template = metadata[self.template_key] if self.template_key in metadata else {} |
| 110 | items = metadata["items"] if "items" in metadata else [] | 131 | items = metadata["items"] if "items" in metadata else [] |
| 111 | 132 | ||
| 112 | items = [item for item in items if not "skip" in item or item["skip"] != True] | 133 | items = [item for item in items if not "skip" in item or item["skip"] != True] |
| 134 | items = self.prepare_items(template, items) | ||
| 135 | items = self.filter_items(items) | ||
| 136 | |||
| 113 | num_images = len(items) | 137 | num_images = len(items) |
| 114 | 138 | ||
| 115 | valid_set_size = int(num_images * 0.1) | 139 | valid_set_size = int(num_images * 0.1) |
| @@ -120,8 +144,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 120 | 144 | ||
| 121 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) | 145 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) |
| 122 | 146 | ||
| 123 | self.data_train = self.prepare_subdata(template, data_train, self.num_class_images) | 147 | self.data_train = self.pad_items(data_train, self.num_class_images) |
| 124 | self.data_val = self.prepare_subdata(template, data_val) | 148 | self.data_val = self.pad_items(data_val) |
| 125 | 149 | ||
| 126 | def setup(self, stage=None): | 150 | def setup(self, stage=None): |
| 127 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 151 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |
diff --git a/dreambooth.py b/dreambooth.py index 31416e9..5521b21 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -57,6 +57,11 @@ def parse_args(): | |||
| 57 | help="A folder containing the training data." | 57 | help="A folder containing the training data." |
| 58 | ) | 58 | ) |
| 59 | parser.add_argument( | 59 | parser.add_argument( |
| 60 | "--train_data_template", | ||
| 61 | type=str, | ||
| 62 | default="template", | ||
| 63 | ) | ||
| 64 | parser.add_argument( | ||
| 60 | "--instance_identifier", | 65 | "--instance_identifier", |
| 61 | type=str, | 66 | type=str, |
| 62 | default=None, | 67 | default=None, |
| @@ -768,6 +773,7 @@ def main(): | |||
| 768 | repeats=args.repeats, | 773 | repeats=args.repeats, |
| 769 | dropout=args.tag_dropout, | 774 | dropout=args.tag_dropout, |
| 770 | center_crop=args.center_crop, | 775 | center_crop=args.center_crop, |
| 776 | template_key=args.train_data_template, | ||
| 771 | valid_set_size=args.valid_set_size, | 777 | valid_set_size=args.valid_set_size, |
| 772 | num_workers=args.dataloader_num_workers, | 778 | num_workers=args.dataloader_num_workers, |
| 773 | collate_fn=collate_fn | 779 | collate_fn=collate_fn |
diff --git a/textual_inversion.py b/textual_inversion.py index 19b8993..fd4a313 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -58,6 +58,11 @@ def parse_args(): | |||
| 58 | help="A CSV file containing the training data." | 58 | help="A CSV file containing the training data." |
| 59 | ) | 59 | ) |
| 60 | parser.add_argument( | 60 | parser.add_argument( |
| 61 | "--train_data_template", | ||
| 62 | type=str, | ||
| 63 | default="template", | ||
| 64 | ) | ||
| 65 | parser.add_argument( | ||
| 61 | "--instance_identifier", | 66 | "--instance_identifier", |
| 62 | type=str, | 67 | type=str, |
| 63 | default=None, | 68 | default=None, |
| @@ -121,7 +126,7 @@ def parse_args(): | |||
| 121 | parser.add_argument( | 126 | parser.add_argument( |
| 122 | "--tag_dropout", | 127 | "--tag_dropout", |
| 123 | type=float, | 128 | type=float, |
| 124 | default=0.1, | 129 | default=0, |
| 125 | help="Tag dropout probability.", | 130 | help="Tag dropout probability.", |
| 126 | ) | 131 | ) |
| 127 | parser.add_argument( | 132 | parser.add_argument( |
| @@ -170,7 +175,7 @@ def parse_args(): | |||
| 170 | parser.add_argument( | 175 | parser.add_argument( |
| 171 | "--lr_scheduler", | 176 | "--lr_scheduler", |
| 172 | type=str, | 177 | type=str, |
| 173 | default="constant_with_warmup", | 178 | default="one_cycle", |
| 174 | help=( | 179 | help=( |
| 175 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 180 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| 176 | ' "constant", "constant_with_warmup", "one_cycle"]' | 181 | ' "constant", "constant_with_warmup", "one_cycle"]' |
| @@ -670,8 +675,10 @@ def main(): | |||
| 670 | repeats=args.repeats, | 675 | repeats=args.repeats, |
| 671 | dropout=args.tag_dropout, | 676 | dropout=args.tag_dropout, |
| 672 | center_crop=args.center_crop, | 677 | center_crop=args.center_crop, |
| 678 | template_key=args.train_data_template, | ||
| 673 | valid_set_size=args.valid_set_size, | 679 | valid_set_size=args.valid_set_size, |
| 674 | num_workers=args.dataloader_num_workers, | 680 | num_workers=args.dataloader_num_workers, |
| 681 | keyword_filter=args.placeholder_token, | ||
| 675 | collate_fn=collate_fn | 682 | collate_fn=collate_fn |
| 676 | ) | 683 | ) |
| 677 | 684 | ||
| @@ -740,7 +747,7 @@ def main(): | |||
| 740 | num_warmup_steps=warmup_steps, | 747 | num_warmup_steps=warmup_steps, |
| 741 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 748 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| 742 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( | 749 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
| 743 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), | 750 | ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))), |
| 744 | ) | 751 | ) |
| 745 | else: | 752 | else: |
| 746 | lr_scheduler = get_scheduler( | 753 | lr_scheduler = get_scheduler( |
