 data/csv.py      | 33 ++++++++++++++++++++++-----------
 train_ti.py      | 13 +++++++++++++
 training/util.py |  9 ++-------
 3 files changed, 37 insertions(+), 18 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 59d6d8d..654aec1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -7,6 +7,8 @@ from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
 from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
+import numpy as np
+
 from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt
 
@@ -56,6 +58,8 @@ class VlpnDataModule():
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
+        num_aspect_ratio_buckets: int = 0,
+        progressive_aspect_ratio_buckets: bool = False,
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
@@ -80,6 +84,8 @@ class VlpnDataModule():
 
         self.prompt_processor = prompt_processor
         self.size = size
+        self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
+        self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
         self.repeats = repeats
         self.dropout = dropout
         self.template_key = template_key
@@ -143,25 +149,32 @@ class VlpnDataModule():
     def generate_buckets(self, items: list[VlpnDataItem]):
         buckets = [VlpnDataBucket(self.size, self.size)]
 
-        for i in range(1, 5):
+        for i in range(1, self.num_aspect_ratio_buckets + 1):
             s = self.size + i * 64
             buckets.append(VlpnDataBucket(s, self.size))
             buckets.append(VlpnDataBucket(self.size, s))
 
+        buckets = np.array(buckets)
+        bucket_ratios = np.array([bucket.ratio for bucket in buckets])
+
         for item in items:
             image = get_image(item.instance_image_path)
             ratio = image.width / image.height
 
             if ratio >= 1:
-                candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio]
+                mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
             else:
-                candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio]
+                mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)
 
-            for bucket in candidates:
+            if not self.progressive_aspect_ratio_buckets:
+                ratios = bucket_ratios.copy()
+                ratios[~mask] = math.inf
+                mask = [np.argmin(np.abs(ratios - ratio))]
+
+            for bucket in buckets[mask]:
                 bucket.items.append(item)
 
-        buckets = [bucket for bucket in buckets if len(bucket.items) != 0]
-        return buckets
+        return [bucket for bucket in buckets if len(bucket.items) != 0]
 
     def setup(self):
         with open(self.data_file, 'rt') as f:
@@ -192,7 +205,7 @@ class VlpnDataModule():
 
         train_datasets = [
             VlpnDataset(
-                bucket.items, self.prompt_processor, batch_size=self.batch_size,
+                bucket.items, self.prompt_processor,
                 width=bucket.width, height=bucket.height, interpolation=self.interpolation,
                 num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
             )
@@ -200,7 +213,7 @@ class VlpnDataModule():
         ]
 
         val_dataset = VlpnDataset(
-            data_val, self.prompt_processor, batch_size=self.batch_size,
+            data_val, self.prompt_processor,
             width=self.size, height=self.size, interpolation=self.interpolation,
         )
 
@@ -223,7 +236,6 @@ class VlpnDataset(Dataset):
         self,
         data: List[VlpnDataItem],
         prompt_processor: PromptProcessor,
-        batch_size: int = 1,
         num_class_images: int = 0,
         width: int = 768,
         height: int = 768,
@@ -234,7 +246,6 @@ class VlpnDataset(Dataset):
 
         self.data = data
         self.prompt_processor = prompt_processor
-        self.batch_size = batch_size
         self.num_class_images = num_class_images
         self.dropout = dropout
 
@@ -258,7 +269,7 @@ class VlpnDataset(Dataset):
         )
 
     def __len__(self):
-        return math.ceil(self._length / self.batch_size) * self.batch_size
+        return self._length
 
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
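The main change in data/csv.py is the bucket assignment in generate_buckets: each image is now matched either to the single bucket whose aspect ratio is closest to its own (the default), or, with progressive buckets enabled, to every bucket lying between the square bucket and the image's own ratio. The sketch below restates that rule in isolation so it can be run on plain (width, height) pairs. It is a minimal sketch, not the repository code: the Bucket class stands in for VlpnDataBucket (assumed here to be a simple width/height/ratio container with an items list), and image sizes are passed in directly instead of being read from disk.

import math
import numpy as np

class Bucket:
    # Stand-in for VlpnDataBucket: a target resolution plus the items assigned to it.
    def __init__(self, width: int, height: int):
        self.width = width
        self.height = height
        self.ratio = width / height
        self.items: list = []

def assign_buckets(size, num_buckets, images, progressive=False):
    # Square bucket plus `num_buckets` landscape and portrait variants,
    # each 64 px longer on one side than the previous one.
    buckets = [Bucket(size, size)]
    for i in range(1, num_buckets + 1):
        s = size + i * 64
        buckets.append(Bucket(s, size))
        buckets.append(Bucket(size, s))

    buckets = np.array(buckets)
    bucket_ratios = np.array([b.ratio for b in buckets])

    for name, (width, height) in images:
        ratio = width / height

        # Candidates lie between the square bucket (ratio 1) and the image's ratio.
        if ratio >= 1:
            mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
        else:
            mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)

        if not progressive:
            # Default mode: keep only the candidate closest to the image's ratio.
            ratios = bucket_ratios.copy()
            ratios[~mask] = math.inf
            mask = [np.argmin(np.abs(ratios - ratio))]

        for bucket in buckets[mask]:
            bucket.items.append(name)

    return [b for b in buckets if b.items]

buckets = assign_buckets(768, 4, [("wide.png", (1920, 1080)), ("square.png", (512, 512))])
print([(b.width, b.height, b.items) for b in buckets])
# [(768, 768, ['square.png']), (1024, 768, ['wide.png'])]

The diff also drops the per-dataset batch_size argument: VlpnDataset.__len__ now returns the raw item count rather than padding it up to a multiple of the batch size, leaving batching entirely to the DataLoader built from each bucket.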
diff --git a/train_ti.py b/train_ti.py
index 89c6672..38c9755 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -146,6 +146,17 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--num_aspect_ratio_buckets",
+        type=int,
+        default=4,
+        help="Number of buckets in either direction (adds 64 pixels per step).",
+    )
+    parser.add_argument(
+        "--progressive_aspect_ratio_buckets",
+        action="store_true",
+        help="Include images in smaller buckets as well.",
+    )
+    parser.add_argument(
         "--tag_dropout",
         type=float,
         default=0.1,
@@ -710,6 +721,8 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
+        num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
+        progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
         repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
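Both new flags are passed straight through to VlpnDataModule in main(); the CLI default of 4 buckets per direction reproduces the previously hard-coded range(1, 5) loop. The snippet below, a small standalone illustration rather than repository code, shows how the two modes differ for one landscape image at --resolution 768: by default only the nearest-ratio bucket is selected, while --progressive_aspect_ratio_buckets selects every landscape bucket up to the image's ratio.

import numpy as np

# Bucket ratios for --resolution 768 and --num_aspect_ratio_buckets 4
# (square bucket plus landscape/portrait variants in 64 px steps).
bucket_ratios = np.array([768/768, 832/768, 768/832, 896/768, 768/896,
                          960/768, 768/960, 1024/768, 768/1024])
ratio = 1920 / 1080  # a landscape image, ratio ~1.78

# Progressive mode keeps every bucket between square and the image's ratio.
mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
print(np.flatnonzero(mask))  # [0 1 3 5 7] -> all landscape buckets plus the square one

# Default mode keeps only the closest candidate.
print(np.argmin(np.where(mask, np.abs(bucket_ratios - ratio), np.inf)))  # 7 -> the 1024x768 bucket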
diff --git a/training/util.py b/training/util.py
index 6f42228..2b7f71d 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 import json
 import copy
+import itertools
 from typing import Iterable, Optional
 from contextlib import contextmanager
 
@@ -71,13 +72,7 @@ class CheckpointerBase:
             file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
-
-            batches = [
-                batch
-                for j, batch in data_enum
-                if j * data.batch_size < self.sample_batch_size * self.sample_batches
-            ]
+            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
             prompts = [
                 prompt
                 for batch in batches
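In training/util.py, CheckpointerBase previously counted off sample batches with enumerate and the dataloader's batch_size attribute; since the dataset no longer carries a batch size, the sampling loop now cycles the dataloader and slices off a fixed number of entries with itertools, so a short validation set is simply reused until enough prompts are collected. A minimal sketch of that pattern, with plain lists standing in for the DataLoader and for the checkpointer's sample_batch_size and sample_batches settings:

import itertools

data = [["a photo of x"], ["a painting of x"]]  # stand-in: each element is one batch of prompts
sample_batch_size = 2
sample_batches = 2

# Repeat the loader indefinitely and take exactly the number of entries needed.
batches = list(itertools.islice(itertools.cycle(data), sample_batch_size * sample_batches))
prompts = [prompt for batch in batches for prompt in batch]

print(prompts)
# ['a photo of x', 'a painting of x', 'a photo of x', 'a painting of x']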
