diff options
-rw-r--r-- | data/csv.py | 38 | ||||
-rw-r--r-- | dreambooth.py | 6 | ||||
-rw-r--r-- | textual_inversion.py | 13 |
3 files changed, 47 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py index 9125212..9c3c3f8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -54,8 +54,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
54 | dropout: float = 0, | 54 | dropout: float = 0, |
55 | interpolation: str = "bicubic", | 55 | interpolation: str = "bicubic", |
56 | center_crop: bool = False, | 56 | center_crop: bool = False, |
57 | template_key: str = "template", | ||
57 | valid_set_size: Optional[int] = None, | 58 | valid_set_size: Optional[int] = None, |
58 | generator: Optional[torch.Generator] = None, | 59 | generator: Optional[torch.Generator] = None, |
60 | keyword_filter: list[str] = [], | ||
59 | collate_fn=None, | 61 | collate_fn=None, |
60 | num_workers: int = 0 | 62 | num_workers: int = 0 |
61 | ): | 63 | ): |
@@ -78,38 +80,60 @@ class CSVDataModule(pl.LightningDataModule): | |||
78 | self.repeats = repeats | 80 | self.repeats = repeats |
79 | self.dropout = dropout | 81 | self.dropout = dropout |
80 | self.center_crop = center_crop | 82 | self.center_crop = center_crop |
83 | self.template_key = template_key | ||
81 | self.interpolation = interpolation | 84 | self.interpolation = interpolation |
82 | self.valid_set_size = valid_set_size | 85 | self.valid_set_size = valid_set_size |
83 | self.generator = generator | 86 | self.generator = generator |
87 | self.keyword_filter = keyword_filter | ||
84 | self.collate_fn = collate_fn | 88 | self.collate_fn = collate_fn |
85 | self.num_workers = num_workers | 89 | self.num_workers = num_workers |
86 | self.batch_size = batch_size | 90 | self.batch_size = batch_size |
87 | 91 | ||
88 | def prepare_subdata(self, template, data, num_class_images=1): | 92 | def prepare_items(self, template, data) -> list[CSVDataItem]: |
89 | image = template["image"] if "image" in template else "{}" | 93 | image = template["image"] if "image" in template else "{}" |
90 | prompt = template["prompt"] if "prompt" in template else "{content}" | 94 | prompt = template["prompt"] if "prompt" in template else "{content}" |
91 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 95 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
92 | 96 | ||
93 | image_multiplier = max(math.ceil(num_class_images / len(data)), 1) | ||
94 | |||
95 | return [ | 97 | return [ |
96 | CSVDataItem( | 98 | CSVDataItem( |
97 | self.data_root.joinpath(image.format(item["image"])), | 99 | self.data_root.joinpath(image.format(item["image"])), |
98 | self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"), | 100 | None, |
99 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), | 101 | prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), |
100 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) | 102 | nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) |
101 | ) | 103 | ) |
102 | for item in data | 104 | for item in data |
105 | ] | ||
106 | |||
107 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: | ||
108 | if len(self.keyword_filter) == 0: | ||
109 | return items | ||
110 | |||
111 | return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)] | ||
112 | |||
113 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | ||
114 | image_multiplier = max(math.ceil(num_class_images / len(items)), 1) | ||
115 | |||
116 | return [ | ||
117 | CSVDataItem( | ||
118 | item.instance_image_path, | ||
119 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | ||
120 | item.prompt, | ||
121 | item.nprompt | ||
122 | ) | ||
123 | for item in items | ||
103 | for i in range(image_multiplier) | 124 | for i in range(image_multiplier) |
104 | ] | 125 | ] |
105 | 126 | ||
106 | def prepare_data(self): | 127 | def prepare_data(self): |
107 | with open(self.data_file, 'rt') as f: | 128 | with open(self.data_file, 'rt') as f: |
108 | metadata = json.load(f) | 129 | metadata = json.load(f) |
109 | template = metadata["template"] if "template" in metadata else {} | 130 | template = metadata[self.template_key] if self.template_key in metadata else {} |
110 | items = metadata["items"] if "items" in metadata else [] | 131 | items = metadata["items"] if "items" in metadata else [] |
111 | 132 | ||
112 | items = [item for item in items if not "skip" in item or item["skip"] != True] | 133 | items = [item for item in items if not "skip" in item or item["skip"] != True] |
134 | items = self.prepare_items(template, items) | ||
135 | items = self.filter_items(items) | ||
136 | |||
113 | num_images = len(items) | 137 | num_images = len(items) |
114 | 138 | ||
115 | valid_set_size = int(num_images * 0.1) | 139 | valid_set_size = int(num_images * 0.1) |
@@ -120,8 +144,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
120 | 144 | ||
121 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) | 145 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) |
122 | 146 | ||
123 | self.data_train = self.prepare_subdata(template, data_train, self.num_class_images) | 147 | self.data_train = self.pad_items(data_train, self.num_class_images) |
124 | self.data_val = self.prepare_subdata(template, data_val) | 148 | self.data_val = self.pad_items(data_val) |
125 | 149 | ||
126 | def setup(self, stage=None): | 150 | def setup(self, stage=None): |
127 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 151 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, |
diff --git a/dreambooth.py b/dreambooth.py index 31416e9..5521b21 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -57,6 +57,11 @@ def parse_args(): | |||
57 | help="A folder containing the training data." | 57 | help="A folder containing the training data." |
58 | ) | 58 | ) |
59 | parser.add_argument( | 59 | parser.add_argument( |
60 | "--train_data_template", | ||
61 | type=str, | ||
62 | default="template", | ||
63 | ) | ||
64 | parser.add_argument( | ||
60 | "--instance_identifier", | 65 | "--instance_identifier", |
61 | type=str, | 66 | type=str, |
62 | default=None, | 67 | default=None, |
@@ -768,6 +773,7 @@ def main(): | |||
768 | repeats=args.repeats, | 773 | repeats=args.repeats, |
769 | dropout=args.tag_dropout, | 774 | dropout=args.tag_dropout, |
770 | center_crop=args.center_crop, | 775 | center_crop=args.center_crop, |
776 | template_key=args.train_data_template, | ||
771 | valid_set_size=args.valid_set_size, | 777 | valid_set_size=args.valid_set_size, |
772 | num_workers=args.dataloader_num_workers, | 778 | num_workers=args.dataloader_num_workers, |
773 | collate_fn=collate_fn | 779 | collate_fn=collate_fn |
diff --git a/textual_inversion.py b/textual_inversion.py index 19b8993..fd4a313 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -58,6 +58,11 @@ def parse_args(): | |||
58 | help="A CSV file containing the training data." | 58 | help="A CSV file containing the training data." |
59 | ) | 59 | ) |
60 | parser.add_argument( | 60 | parser.add_argument( |
61 | "--train_data_template", | ||
62 | type=str, | ||
63 | default="template", | ||
64 | ) | ||
65 | parser.add_argument( | ||
61 | "--instance_identifier", | 66 | "--instance_identifier", |
62 | type=str, | 67 | type=str, |
63 | default=None, | 68 | default=None, |
@@ -121,7 +126,7 @@ def parse_args(): | |||
121 | parser.add_argument( | 126 | parser.add_argument( |
122 | "--tag_dropout", | 127 | "--tag_dropout", |
123 | type=float, | 128 | type=float, |
124 | default=0.1, | 129 | default=0, |
125 | help="Tag dropout probability.", | 130 | help="Tag dropout probability.", |
126 | ) | 131 | ) |
127 | parser.add_argument( | 132 | parser.add_argument( |
@@ -170,7 +175,7 @@ def parse_args(): | |||
170 | parser.add_argument( | 175 | parser.add_argument( |
171 | "--lr_scheduler", | 176 | "--lr_scheduler", |
172 | type=str, | 177 | type=str, |
173 | default="constant_with_warmup", | 178 | default="one_cycle", |
174 | help=( | 179 | help=( |
175 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 180 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
176 | ' "constant", "constant_with_warmup", "one_cycle"]' | 181 | ' "constant", "constant_with_warmup", "one_cycle"]' |
@@ -670,8 +675,10 @@ def main(): | |||
670 | repeats=args.repeats, | 675 | repeats=args.repeats, |
671 | dropout=args.tag_dropout, | 676 | dropout=args.tag_dropout, |
672 | center_crop=args.center_crop, | 677 | center_crop=args.center_crop, |
678 | template_key=args.train_data_template, | ||
673 | valid_set_size=args.valid_set_size, | 679 | valid_set_size=args.valid_set_size, |
674 | num_workers=args.dataloader_num_workers, | 680 | num_workers=args.dataloader_num_workers, |
681 | keyword_filter=args.placeholder_token, | ||
675 | collate_fn=collate_fn | 682 | collate_fn=collate_fn |
676 | ) | 683 | ) |
677 | 684 | ||
@@ -740,7 +747,7 @@ def main(): | |||
740 | num_warmup_steps=warmup_steps, | 747 | num_warmup_steps=warmup_steps, |
741 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 748 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
742 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( | 749 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( |
743 | ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), | 750 | ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))), |
744 | ) | 751 | ) |
745 | else: | 752 | else: |
746 | lr_scheduler = get_scheduler( | 753 | lr_scheduler = get_scheduler( |