-rw-r--r--  data/csv.py      | 33
-rw-r--r--  train_ti.py      | 13
-rw-r--r--  training/util.py |  9
3 files changed, 37 insertions(+), 18 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 59d6d8d..654aec1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -7,6 +7,8 @@ from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
 from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
+import numpy as np
+
 from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt
 
@@ -56,6 +58,8 @@ class VlpnDataModule():
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
+        num_aspect_ratio_buckets: int = 0,
+        progressive_aspect_ratio_buckets: bool = False,
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
@@ -80,6 +84,8 @@ class VlpnDataModule():
 
         self.prompt_processor = prompt_processor
         self.size = size
+        self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
+        self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
         self.repeats = repeats
         self.dropout = dropout
         self.template_key = template_key
@@ -143,25 +149,32 @@ class VlpnDataModule():
     def generate_buckets(self, items: list[VlpnDataItem]):
         buckets = [VlpnDataBucket(self.size, self.size)]
 
-        for i in range(1, 5):
+        for i in range(1, self.num_aspect_ratio_buckets + 1):
             s = self.size + i * 64
             buckets.append(VlpnDataBucket(s, self.size))
             buckets.append(VlpnDataBucket(self.size, s))
 
+        buckets = np.array(buckets)
+        bucket_ratios = np.array([bucket.ratio for bucket in buckets])
+
         for item in items:
             image = get_image(item.instance_image_path)
             ratio = image.width / image.height
 
             if ratio >= 1:
-                candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio]
+                mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
             else:
-                candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio]
+                mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)
 
-            for bucket in candidates:
+            if not self.progressive_aspect_ratio_buckets:
+                ratios = bucket_ratios.copy()
+                ratios[~mask] = math.inf
+                mask = [np.argmin(np.abs(ratios - ratio))]
+
+            for bucket in buckets[mask]:
                 bucket.items.append(item)
 
-        buckets = [bucket for bucket in buckets if len(bucket.items) != 0]
-        return buckets
+        return [bucket for bucket in buckets if len(bucket.items) != 0]
 
     def setup(self):
         with open(self.data_file, 'rt') as f:
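Note: in the default (non-progressive) mode, the masked argmin above assigns each image to exactly one bucket, the one whose aspect ratio is closest to the image's. A minimal standalone sketch of that selection, with made-up ratios mirroring size=768 and 64-pixel steps (the sample values are hypothetical):

    import math
    import numpy as np

    # Ratios for buckets 768x768, 832x768, 768x832, 896x768, 768x896.
    bucket_ratios = np.array([1.0, 832 / 768, 768 / 832, 896 / 768, 768 / 896])
    ratio = 1.12  # width / height of an example landscape image

    # Landscape image: only landscape buckets whose ratio does not exceed the image's qualify.
    mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)

    # Collapse the mask to the single closest-ratio bucket by pushing
    # non-candidates to infinity before taking the argmin.
    ratios = bucket_ratios.copy()
    ratios[~mask] = math.inf
    print(np.argmin(np.abs(ratios - ratio)))  # -> 1, i.e. the 832x768 bucket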
@@ -192,7 +205,7 @@ class VlpnDataModule():
 
         train_datasets = [
             VlpnDataset(
-                bucket.items, self.prompt_processor, batch_size=self.batch_size,
+                bucket.items, self.prompt_processor,
                 width=bucket.width, height=bucket.height, interpolation=self.interpolation,
                 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
             )
@@ -200,7 +213,7 @@ class VlpnDataModule():
         ]
 
         val_dataset = VlpnDataset(
-            data_val, self.prompt_processor, batch_size=self.batch_size,
+            data_val, self.prompt_processor,
             width=self.size, height=self.size, interpolation=self.interpolation,
         )
 
@@ -223,7 +236,6 @@ class VlpnDataset(Dataset):
         self,
         data: List[VlpnDataItem],
         prompt_processor: PromptProcessor,
-        batch_size: int = 1,
         num_class_images: int = 0,
         width: int = 768,
         height: int = 768,
@@ -234,7 +246,6 @@ class VlpnDataset(Dataset):
 
         self.data = data
         self.prompt_processor = prompt_processor
-        self.batch_size = batch_size
         self.num_class_images = num_class_images
         self.dropout = dropout
 
@@ -258,7 +269,7 @@ class VlpnDataset(Dataset):
         )
 
     def __len__(self):
-        return math.ceil(self._length / self.batch_size) * self.batch_size
+        return self._length
 
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
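Aside: with batch_size gone, __len__ now reports the true item count. The old version rounded the length up so the final incomplete batch was padded; a quick check of the arithmetic (the numbers are hypothetical):

    import math

    # Old behavior: 10 items with batch_size 4 reported a length of 12.
    print(math.ceil(10 / 4) * 4)  # -> 12
    # New behavior: the dataset simply reports 10; batching is handled by the DataLoader.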
diff --git a/train_ti.py b/train_ti.py
index 89c6672..38c9755 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -146,6 +146,17 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--num_aspect_ratio_buckets",
+        type=int,
+        default=4,
+        help="Number of buckets in either direction (adds 64 pixels per step).",
+    )
+    parser.add_argument(
+        "--progressive_aspect_ratio_buckets",
+        action="store_true",
+        help="Include images in smaller buckets as well.",
+    )
+    parser.add_argument(
         "--tag_dropout",
         type=float,
         default=0.1,
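For reference, the new flags are passed like any other train_ti.py option; a hypothetical invocation (all other required arguments elided, and the --resolution flag name assumed from args.resolution below):

    python train_ti.py \
        --resolution 768 \
        --num_aspect_ratio_buckets 4 \
        --progressive_aspect_ratio_buckets \
        ...  # remaining arguments as usual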
@@ -710,6 +721,8 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
+        num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
+        progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
         repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
diff --git a/training/util.py b/training/util.py
index 6f42228..2b7f71d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 import json
 import copy
+import itertools
 from typing import Iterable, Optional
 from contextlib import contextmanager
 
@@ -71,13 +72,7 @@ class CheckpointerBase:
             file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
-
-            batches = [
-                batch
-                for j, batch in data_enum
-                if j * data.batch_size < self.sample_batch_size * self.sample_batches
-            ]
+            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
             prompts = [
                 prompt
                 for batch in batches
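The itertools one-liner above drops the dependency on the loader's batch_size attribute: cycle repeats the sample data indefinitely and islice cuts it off after a fixed count, wrapping around when the source is shorter than requested. A minimal sketch of the pattern, with a plain list standing in for the DataLoader:

    import itertools

    data = ["batch0", "batch1", "batch2"]  # stand-in for a short sample loader

    # Take exactly 5 elements, restarting from the beginning when data runs out.
    batches = list(itertools.islice(itertools.cycle(data), 5))
    print(batches)  # ['batch0', 'batch1', 'batch2', 'batch0', 'batch1']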