summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py38
-rw-r--r--dreambooth.py6
-rw-r--r--textual_inversion.py13
3 files changed, 47 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py
index 9125212..9c3c3f8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -54,8 +54,10 @@ class CSVDataModule(pl.LightningDataModule):
54 dropout: float = 0, 54 dropout: float = 0,
55 interpolation: str = "bicubic", 55 interpolation: str = "bicubic",
56 center_crop: bool = False, 56 center_crop: bool = False,
57 template_key: str = "template",
57 valid_set_size: Optional[int] = None, 58 valid_set_size: Optional[int] = None,
58 generator: Optional[torch.Generator] = None, 59 generator: Optional[torch.Generator] = None,
60 keyword_filter: list[str] = [],
59 collate_fn=None, 61 collate_fn=None,
60 num_workers: int = 0 62 num_workers: int = 0
61 ): 63 ):
@@ -78,38 +80,60 @@ class CSVDataModule(pl.LightningDataModule):
78 self.repeats = repeats 80 self.repeats = repeats
79 self.dropout = dropout 81 self.dropout = dropout
80 self.center_crop = center_crop 82 self.center_crop = center_crop
83 self.template_key = template_key
81 self.interpolation = interpolation 84 self.interpolation = interpolation
82 self.valid_set_size = valid_set_size 85 self.valid_set_size = valid_set_size
83 self.generator = generator 86 self.generator = generator
87 self.keyword_filter = keyword_filter
84 self.collate_fn = collate_fn 88 self.collate_fn = collate_fn
85 self.num_workers = num_workers 89 self.num_workers = num_workers
86 self.batch_size = batch_size 90 self.batch_size = batch_size
87 91
88 def prepare_subdata(self, template, data, num_class_images=1): 92 def prepare_items(self, template, data) -> list[CSVDataItem]:
89 image = template["image"] if "image" in template else "{}" 93 image = template["image"] if "image" in template else "{}"
90 prompt = template["prompt"] if "prompt" in template else "{content}" 94 prompt = template["prompt"] if "prompt" in template else "{content}"
91 nprompt = template["nprompt"] if "nprompt" in template else "{content}" 95 nprompt = template["nprompt"] if "nprompt" in template else "{content}"
92 96
93 image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
94
95 return [ 97 return [
96 CSVDataItem( 98 CSVDataItem(
97 self.data_root.joinpath(image.format(item["image"])), 99 self.data_root.joinpath(image.format(item["image"])),
98 self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"), 100 None,
99 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 101 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
100 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) 102 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
101 ) 103 )
102 for item in data 104 for item in data
105 ]
106
107 def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
108 if len(self.keyword_filter) == 0:
109 return items
110
111 return [item for item in items if any(keyword in item.prompt for keyword in self.keyword_filter)]
112
113 def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
114 image_multiplier = max(math.ceil(num_class_images / len(items)), 1)
115
116 return [
117 CSVDataItem(
118 item.instance_image_path,
119 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
120 item.prompt,
121 item.nprompt
122 )
123 for item in items
103 for i in range(image_multiplier) 124 for i in range(image_multiplier)
104 ] 125 ]
105 126
106 def prepare_data(self): 127 def prepare_data(self):
107 with open(self.data_file, 'rt') as f: 128 with open(self.data_file, 'rt') as f:
108 metadata = json.load(f) 129 metadata = json.load(f)
109 template = metadata["template"] if "template" in metadata else {} 130 template = metadata[self.template_key] if self.template_key in metadata else {}
110 items = metadata["items"] if "items" in metadata else [] 131 items = metadata["items"] if "items" in metadata else []
111 132
112 items = [item for item in items if not "skip" in item or item["skip"] != True] 133 items = [item for item in items if not "skip" in item or item["skip"] != True]
134 items = self.prepare_items(template, items)
135 items = self.filter_items(items)
136
113 num_images = len(items) 137 num_images = len(items)
114 138
115 valid_set_size = int(num_images * 0.1) 139 valid_set_size = int(num_images * 0.1)
@@ -120,8 +144,8 @@ class CSVDataModule(pl.LightningDataModule):
120 144
121 data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) 145 data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
122 146
123 self.data_train = self.prepare_subdata(template, data_train, self.num_class_images) 147 self.data_train = self.pad_items(data_train, self.num_class_images)
124 self.data_val = self.prepare_subdata(template, data_val) 148 self.data_val = self.pad_items(data_val)
125 149
126 def setup(self, stage=None): 150 def setup(self, stage=None):
127 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, 151 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
diff --git a/dreambooth.py b/dreambooth.py
index 31416e9..5521b21 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -57,6 +57,11 @@ def parse_args():
57 help="A folder containing the training data." 57 help="A folder containing the training data."
58 ) 58 )
59 parser.add_argument( 59 parser.add_argument(
60 "--train_data_template",
61 type=str,
62 default="template",
63 )
64 parser.add_argument(
60 "--instance_identifier", 65 "--instance_identifier",
61 type=str, 66 type=str,
62 default=None, 67 default=None,
@@ -768,6 +773,7 @@ def main():
768 repeats=args.repeats, 773 repeats=args.repeats,
769 dropout=args.tag_dropout, 774 dropout=args.tag_dropout,
770 center_crop=args.center_crop, 775 center_crop=args.center_crop,
776 template_key=args.train_data_template,
771 valid_set_size=args.valid_set_size, 777 valid_set_size=args.valid_set_size,
772 num_workers=args.dataloader_num_workers, 778 num_workers=args.dataloader_num_workers,
773 collate_fn=collate_fn 779 collate_fn=collate_fn
diff --git a/textual_inversion.py b/textual_inversion.py
index 19b8993..fd4a313 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -58,6 +58,11 @@ def parse_args():
58 help="A CSV file containing the training data." 58 help="A CSV file containing the training data."
59 ) 59 )
60 parser.add_argument( 60 parser.add_argument(
61 "--train_data_template",
62 type=str,
63 default="template",
64 )
65 parser.add_argument(
61 "--instance_identifier", 66 "--instance_identifier",
62 type=str, 67 type=str,
63 default=None, 68 default=None,
@@ -121,7 +126,7 @@ def parse_args():
121 parser.add_argument( 126 parser.add_argument(
122 "--tag_dropout", 127 "--tag_dropout",
123 type=float, 128 type=float,
124 default=0.1, 129 default=0,
125 help="Tag dropout probability.", 130 help="Tag dropout probability.",
126 ) 131 )
127 parser.add_argument( 132 parser.add_argument(
@@ -170,7 +175,7 @@ def parse_args():
170 parser.add_argument( 175 parser.add_argument(
171 "--lr_scheduler", 176 "--lr_scheduler",
172 type=str, 177 type=str,
173 default="constant_with_warmup", 178 default="one_cycle",
174 help=( 179 help=(
175 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 180 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
176 ' "constant", "constant_with_warmup", "one_cycle"]' 181 ' "constant", "constant_with_warmup", "one_cycle"]'
@@ -670,8 +675,10 @@ def main():
670 repeats=args.repeats, 675 repeats=args.repeats,
671 dropout=args.tag_dropout, 676 dropout=args.tag_dropout,
672 center_crop=args.center_crop, 677 center_crop=args.center_crop,
678 template_key=args.train_data_template,
673 valid_set_size=args.valid_set_size, 679 valid_set_size=args.valid_set_size,
674 num_workers=args.dataloader_num_workers, 680 num_workers=args.dataloader_num_workers,
681 keyword_filter=args.placeholder_token,
675 collate_fn=collate_fn 682 collate_fn=collate_fn
676 ) 683 )
677 684
@@ -740,7 +747,7 @@ def main():
740 num_warmup_steps=warmup_steps, 747 num_warmup_steps=warmup_steps,
741 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 748 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
742 num_cycles=args.lr_cycles or math.ceil(math.sqrt( 749 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
743 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))), 750 ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
744 ) 751 )
745 else: 752 else:
746 lr_scheduler = get_scheduler( 753 lr_scheduler = get_scheduler(