summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-07 15:05:39 +0100
committerVolpeon <git@volpeon.ink>2023-01-07 15:05:39 +0100
commit6970adaff742ac89adb3d85c803689210dc030e2 (patch)
tree042eec1c77b800c3b64eff4b8cc40f0a7b153e4d
parentAdded progressive aspect ratio bucketing (diff)
downloadtextual-inversion-diff-6970adaff742ac89adb3d85c803689210dc030e2.tar.gz
textual-inversion-diff-6970adaff742ac89adb3d85c803689210dc030e2.tar.bz2
textual-inversion-diff-6970adaff742ac89adb3d85c803689210dc030e2.zip
Made aspect ratio bucketing configurable
-rw-r--r--data/csv.py33
-rw-r--r--train_ti.py13
-rw-r--r--training/util.py9
3 files changed, 37 insertions, 18 deletions
diff --git a/data/csv.py b/data/csv.py
index 59d6d8d..654aec1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -7,6 +7,8 @@ from torch.utils.data import Dataset, DataLoader, random_split
7from torchvision import transforms 7from torchvision import transforms
8from typing import Dict, NamedTuple, List, Optional, Union, Callable 8from typing import Dict, NamedTuple, List, Optional, Union, Callable
9 9
10import numpy as np
11
10from models.clip.prompt import PromptProcessor 12from models.clip.prompt import PromptProcessor
11from data.keywords import prompt_to_keywords, keywords_to_prompt 13from data.keywords import prompt_to_keywords, keywords_to_prompt
12 14
@@ -56,6 +58,8 @@ class VlpnDataModule():
56 class_subdir: str = "cls", 58 class_subdir: str = "cls",
57 num_class_images: int = 1, 59 num_class_images: int = 1,
58 size: int = 768, 60 size: int = 768,
61 num_aspect_ratio_buckets: int = 0,
62 progressive_aspect_ratio_buckets: bool = False,
59 repeats: int = 1, 63 repeats: int = 1,
60 dropout: float = 0, 64 dropout: float = 0,
61 interpolation: str = "bicubic", 65 interpolation: str = "bicubic",
@@ -80,6 +84,8 @@ class VlpnDataModule():
80 84
81 self.prompt_processor = prompt_processor 85 self.prompt_processor = prompt_processor
82 self.size = size 86 self.size = size
87 self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
88 self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
83 self.repeats = repeats 89 self.repeats = repeats
84 self.dropout = dropout 90 self.dropout = dropout
85 self.template_key = template_key 91 self.template_key = template_key
@@ -143,25 +149,32 @@ class VlpnDataModule():
143 def generate_buckets(self, items: list[VlpnDataItem]): 149 def generate_buckets(self, items: list[VlpnDataItem]):
144 buckets = [VlpnDataBucket(self.size, self.size)] 150 buckets = [VlpnDataBucket(self.size, self.size)]
145 151
146 for i in range(1, 5): 152 for i in range(1, self.num_aspect_ratio_buckets + 1):
147 s = self.size + i * 64 153 s = self.size + i * 64
148 buckets.append(VlpnDataBucket(s, self.size)) 154 buckets.append(VlpnDataBucket(s, self.size))
149 buckets.append(VlpnDataBucket(self.size, s)) 155 buckets.append(VlpnDataBucket(self.size, s))
150 156
157 buckets = np.array(buckets)
158 bucket_ratios = np.array([bucket.ratio for bucket in buckets])
159
151 for item in items: 160 for item in items:
152 image = get_image(item.instance_image_path) 161 image = get_image(item.instance_image_path)
153 ratio = image.width / image.height 162 ratio = image.width / image.height
154 163
155 if ratio >= 1: 164 if ratio >= 1:
156 candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio] 165 mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
157 else: 166 else:
158 candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio] 167 mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)
159 168
160 for bucket in candidates: 169 if not self.progressive_aspect_ratio_buckets:
170 ratios = bucket_ratios.copy()
171 ratios[~mask] = math.inf
172 mask = [np.argmin(np.abs(ratios - ratio))]
173
174 for bucket in buckets[mask]:
161 bucket.items.append(item) 175 bucket.items.append(item)
162 176
163 buckets = [bucket for bucket in buckets if len(bucket.items) != 0] 177 return [bucket for bucket in buckets if len(bucket.items) != 0]
164 return buckets
165 178
166 def setup(self): 179 def setup(self):
167 with open(self.data_file, 'rt') as f: 180 with open(self.data_file, 'rt') as f:
@@ -192,7 +205,7 @@ class VlpnDataModule():
192 205
193 train_datasets = [ 206 train_datasets = [
194 VlpnDataset( 207 VlpnDataset(
195 bucket.items, self.prompt_processor, batch_size=self.batch_size, 208 bucket.items, self.prompt_processor,
196 width=bucket.width, height=bucket.height, interpolation=self.interpolation, 209 width=bucket.width, height=bucket.height, interpolation=self.interpolation,
197 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, 210 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
198 ) 211 )
@@ -200,7 +213,7 @@ class VlpnDataModule():
200 ] 213 ]
201 214
202 val_dataset = VlpnDataset( 215 val_dataset = VlpnDataset(
203 data_val, self.prompt_processor, batch_size=self.batch_size, 216 data_val, self.prompt_processor,
204 width=self.size, height=self.size, interpolation=self.interpolation, 217 width=self.size, height=self.size, interpolation=self.interpolation,
205 ) 218 )
206 219
@@ -223,7 +236,6 @@ class VlpnDataset(Dataset):
223 self, 236 self,
224 data: List[VlpnDataItem], 237 data: List[VlpnDataItem],
225 prompt_processor: PromptProcessor, 238 prompt_processor: PromptProcessor,
226 batch_size: int = 1,
227 num_class_images: int = 0, 239 num_class_images: int = 0,
228 width: int = 768, 240 width: int = 768,
229 height: int = 768, 241 height: int = 768,
@@ -234,7 +246,6 @@ class VlpnDataset(Dataset):
234 246
235 self.data = data 247 self.data = data
236 self.prompt_processor = prompt_processor 248 self.prompt_processor = prompt_processor
237 self.batch_size = batch_size
238 self.num_class_images = num_class_images 249 self.num_class_images = num_class_images
239 self.dropout = dropout 250 self.dropout = dropout
240 251
@@ -258,7 +269,7 @@ class VlpnDataset(Dataset):
258 ) 269 )
259 270
260 def __len__(self): 271 def __len__(self):
261 return math.ceil(self._length / self.batch_size) * self.batch_size 272 return self._length
262 273
263 def get_example(self, i): 274 def get_example(self, i):
264 item = self.data[i % self.num_instance_images] 275 item = self.data[i % self.num_instance_images]
diff --git a/train_ti.py b/train_ti.py
index 89c6672..38c9755 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -146,6 +146,17 @@ def parse_args():
146 ), 146 ),
147 ) 147 )
148 parser.add_argument( 148 parser.add_argument(
149 "--num_aspect_ratio_buckets",
150 type=int,
151 default=4,
152 help="Number of buckets in either direction (adds 64 pixels per step).",
153 )
154 parser.add_argument(
155 "--progressive_aspect_ratio_buckets",
156 action="store_true",
157 help="Include images in smaller buckets as well.",
158 )
159 parser.add_argument(
149 "--tag_dropout", 160 "--tag_dropout",
150 type=float, 161 type=float,
151 default=0.1, 162 default=0.1,
@@ -710,6 +721,8 @@ def main():
710 class_subdir=args.class_image_dir, 721 class_subdir=args.class_image_dir,
711 num_class_images=args.num_class_images, 722 num_class_images=args.num_class_images,
712 size=args.resolution, 723 size=args.resolution,
724 num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
725 progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
713 repeats=args.repeats, 726 repeats=args.repeats,
714 dropout=args.tag_dropout, 727 dropout=args.tag_dropout,
715 template_key=args.train_data_template, 728 template_key=args.train_data_template,
diff --git a/training/util.py b/training/util.py
index 6f42228..2b7f71d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,7 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import copy 3import copy
4import itertools
4from typing import Iterable, Optional 5from typing import Iterable, Optional
5from contextlib import contextmanager 6from contextlib import contextmanager
6 7
@@ -71,13 +72,7 @@ class CheckpointerBase:
71 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 72 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
72 file_path.parent.mkdir(parents=True, exist_ok=True) 73 file_path.parent.mkdir(parents=True, exist_ok=True)
73 74
74 data_enum = enumerate(data) 75 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
75
76 batches = [
77 batch
78 for j, batch in data_enum
79 if j * data.batch_size < self.sample_batch_size * self.sample_batches
80 ]
81 prompts = [ 76 prompts = [
82 prompt 77 prompt
83 for batch in batches 78 for batch in batches