author    Volpeon <git@volpeon.ink>  2023-01-08 20:33:04 +0100
committer Volpeon <git@volpeon.ink>  2023-01-08 20:33:04 +0100
commit    ecb12378da48fc3a17539d5cc33edc561cf8a426 (patch)
tree      30517efe41d557a4c1f2661e80e4c0b87e807048
parent    Fixed aspect ratio bucketing (diff)
Improved aspect ratio bucketing
-rw-r--r--  data/csv.py          22
-rw-r--r--  train_dreambooth.py  27
-rw-r--r--  train_ti.py          16
3 files changed, 61 insertions, 4 deletions
diff --git a/data/csv.py b/data/csv.py
index 7527b7d..55a1988 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -44,18 +44,25 @@ def generate_buckets(
     items: list[str],
     base_size: int,
     step_size: int = 64,
+    max_pixels: Optional[int] = None,
     num_buckets: int = 4,
     progressive_buckets: bool = False,
     return_tensor: bool = True
 ):
+    if max_pixels is None:
+        max_pixels = (base_size + step_size) ** 2
+
+    max_pixels = max(max_pixels, base_size * base_size)
+
     bucket_items: list[int] = []
     bucket_assignments: list[int] = []
     buckets = [1.0]
 
     for i in range(1, num_buckets + 1):
-        s = base_size + i * step_size
-        buckets.append(s / base_size)
-        buckets.append(base_size / s)
+        long_side = base_size + i * step_size
+        short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
+        buckets.append(long_side / short_side)
+        buckets.append(short_side / long_side)
 
     buckets = torch.tensor(buckets)
     bucket_indices = torch.arange(len(buckets))
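
For reference, a minimal standalone sketch (not part of the commit) of what the new long_side/short_side arithmetic produces with the trainers' defaults (base size 768, step 64, max_pixels left unset so it falls back to (768 + 64) ** 2 = 692224): the short side is stepped down in multiples of step_size until long_side * short_side fits under max_pixels.

import math

base_size, step_size = 768, 64
max_pixels = (base_size + step_size) ** 2  # fallback added in this commit: 692224

for i in range(1, 4 + 1):  # num_buckets = 4
    long_side = base_size + i * step_size
    # Same expression as in generate_buckets above
    short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
    print(long_side, short_side, long_side * short_side)

# Output:
#  832 768 638976
#  896 768 688128
#  960 704 675840
# 1024 640 655360
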
@@ -110,6 +117,8 @@ class VlpnDataModule():
         num_class_images: int = 1,
         size: int = 768,
         num_buckets: int = 0,
+        bucket_step_size: int = 64,
+        max_pixels_per_bucket: Optional[int] = None,
         progressive_buckets: bool = False,
         dropout: float = 0,
         interpolation: str = "bicubic",
@@ -135,6 +144,8 @@ class VlpnDataModule():
         self.prompt_processor = prompt_processor
         self.size = size
         self.num_buckets = num_buckets
+        self.bucket_step_size = bucket_step_size
+        self.max_pixels_per_bucket = max_pixels_per_bucket
         self.progressive_buckets = progressive_buckets
         self.dropout = dropout
         self.template_key = template_key
@@ -223,6 +234,7 @@ class VlpnDataModule():
         train_dataset = VlpnDataset(
             self.data_train, self.prompt_processor,
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
+            bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket,
             batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
@@ -251,6 +263,8 @@ class VlpnDataset(IterableDataset):
         items: list[VlpnDataItem],
         prompt_processor: PromptProcessor,
         num_buckets: int = 1,
+        bucket_step_size: int = 64,
+        max_pixels_per_bucket: Optional[int] = None,
         progressive_buckets: bool = False,
         batch_size: int = 1,
         num_class_images: int = 0,
@@ -274,7 +288,9 @@ class VlpnDataset(IterableDataset):
         self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
             [item.instance_image_path for item in items],
             base_size=size,
+            step_size=bucket_step_size,
             num_buckets=num_buckets,
+            max_pixels=max_pixels_per_bucket,
             progressive_buckets=progressive_buckets,
         )
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 79eede6..d396249 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -104,6 +104,29 @@ def parse_args():
         help="Number of epochs the text encoder will be trained."
     )
     parser.add_argument(
+        "--num_buckets",
+        type=int,
+        default=4,
+        help="Number of aspect ratio buckets in either direction.",
+    )
+    parser.add_argument(
+        "--progressive_buckets",
+        action="store_true",
+        help="Include images in smaller buckets as well.",
+    )
+    parser.add_argument(
+        "--bucket_step_size",
+        type=int,
+        default=64,
+        help="Step size between buckets.",
+    )
+    parser.add_argument(
+        "--bucket_max_pixels",
+        type=int,
+        default=None,
+        help="Maximum pixels per bucket.",
+    )
+    parser.add_argument(
         "--tag_dropout",
         type=float,
         default=0.1,
@@ -734,6 +757,10 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
+        bucket_step_size=args.bucket_step_size,
+        bucket_max_pixels=args.bucket_max_pixels,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
diff --git a/train_ti.py b/train_ti.py
index 323ef10..eb0b8b6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -143,7 +143,7 @@ def parse_args():
143 "--num_buckets", 143 "--num_buckets",
144 type=int, 144 type=int,
145 default=4, 145 default=4,
146 help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).", 146 help="Number of aspect ratio buckets in either direction.",
147 ) 147 )
148 parser.add_argument( 148 parser.add_argument(
149 "--progressive_buckets", 149 "--progressive_buckets",
@@ -151,6 +151,18 @@ def parse_args():
         help="Include images in smaller buckets as well.",
     )
     parser.add_argument(
+        "--bucket_step_size",
+        type=int,
+        default=64,
+        help="Step size between buckets.",
+    )
+    parser.add_argument(
+        "--bucket_max_pixels",
+        type=int,
+        default=None,
+        help="Maximum pixels per bucket.",
+    )
+    parser.add_argument(
         "--tag_dropout",
         type=float,
         default=0.1,
@@ -718,6 +730,8 @@ def main():
         size=args.resolution,
         num_buckets=args.num_buckets,
         progressive_buckets=args.progressive_buckets,
+        bucket_step_size=args.bucket_step_size,
+        bucket_max_pixels=args.bucket_max_pixels,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
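
As a usage note, here is a hedged sketch of how an explicit --bucket_max_pixels / max_pixels_per_bucket value tightens the buckets compared with the default fallback of (size + step) ** 2, reusing the arithmetic from data/csv.py. bucket_sides is a hypothetical helper for illustration only, not part of the codebase; note that generate_buckets clamps any value below size * size back up to size * size.

import math

def bucket_sides(base_size, step_size, num_buckets, max_pixels=None):
    # Hypothetical helper mirroring the updated generate_buckets() arithmetic.
    if max_pixels is None:
        max_pixels = (base_size + step_size) ** 2
    max_pixels = max(max_pixels, base_size * base_size)
    sides = []
    for i in range(1, num_buckets + 1):
        long_side = base_size + i * step_size
        short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
        sides.append((long_side, short_side))
    return sides

# Explicit cap at the square training area, e.g. --bucket_max_pixels 589824 with --resolution 768:
print(bucket_sides(768, 64, 4, max_pixels=768 * 768))
# [(832, 704), (896, 640), (960, 576), (1024, 576)]
# versus (832, 768), (896, 768), (960, 704), (1024, 640) with the default cap shown earlier.
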