summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/csv.py22
1 files changed, 19 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 7527b7d..55a1988 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -44,18 +44,25 @@ def generate_buckets(
44 items: list[str], 44 items: list[str],
45 base_size: int, 45 base_size: int,
46 step_size: int = 64, 46 step_size: int = 64,
47 max_pixels: Optional[int] = None,
47 num_buckets: int = 4, 48 num_buckets: int = 4,
48 progressive_buckets: bool = False, 49 progressive_buckets: bool = False,
49 return_tensor: bool = True 50 return_tensor: bool = True
50): 51):
52 if max_pixels is None:
53 max_pixels = (base_size + step_size) ** 2
54
55 max_pixels = max(max_pixels, base_size * base_size)
56
51 bucket_items: list[int] = [] 57 bucket_items: list[int] = []
52 bucket_assignments: list[int] = [] 58 bucket_assignments: list[int] = []
53 buckets = [1.0] 59 buckets = [1.0]
54 60
55 for i in range(1, num_buckets + 1): 61 for i in range(1, num_buckets + 1):
56 s = base_size + i * step_size 62 long_side = base_size + i * step_size
57 buckets.append(s / base_size) 63 short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
58 buckets.append(base_size / s) 64 buckets.append(long_side / short_side)
65 buckets.append(short_side / long_side)
59 66
60 buckets = torch.tensor(buckets) 67 buckets = torch.tensor(buckets)
61 bucket_indices = torch.arange(len(buckets)) 68 bucket_indices = torch.arange(len(buckets))
@@ -110,6 +117,8 @@ class VlpnDataModule():
110 num_class_images: int = 1, 117 num_class_images: int = 1,
111 size: int = 768, 118 size: int = 768,
112 num_buckets: int = 0, 119 num_buckets: int = 0,
120 bucket_step_size: int = 64,
121 max_pixels_per_bucket: Optional[int] = None,
113 progressive_buckets: bool = False, 122 progressive_buckets: bool = False,
114 dropout: float = 0, 123 dropout: float = 0,
115 interpolation: str = "bicubic", 124 interpolation: str = "bicubic",
@@ -135,6 +144,8 @@ class VlpnDataModule():
135 self.prompt_processor = prompt_processor 144 self.prompt_processor = prompt_processor
136 self.size = size 145 self.size = size
137 self.num_buckets = num_buckets 146 self.num_buckets = num_buckets
147 self.bucket_step_size = bucket_step_size
148 self.max_pixels_per_bucket = max_pixels_per_bucket
138 self.progressive_buckets = progressive_buckets 149 self.progressive_buckets = progressive_buckets
139 self.dropout = dropout 150 self.dropout = dropout
140 self.template_key = template_key 151 self.template_key = template_key
@@ -223,6 +234,7 @@ class VlpnDataModule():
223 train_dataset = VlpnDataset( 234 train_dataset = VlpnDataset(
224 self.data_train, self.prompt_processor, 235 self.data_train, self.prompt_processor,
225 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 236 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
237 bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket,
226 batch_size=self.batch_size, generator=generator, 238 batch_size=self.batch_size, generator=generator,
227 size=self.size, interpolation=self.interpolation, 239 size=self.size, interpolation=self.interpolation,
228 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, 240 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
@@ -251,6 +263,8 @@ class VlpnDataset(IterableDataset):
251 items: list[VlpnDataItem], 263 items: list[VlpnDataItem],
252 prompt_processor: PromptProcessor, 264 prompt_processor: PromptProcessor,
253 num_buckets: int = 1, 265 num_buckets: int = 1,
266 bucket_step_size: int = 64,
267 max_pixels_per_bucket: Optional[int] = None,
254 progressive_buckets: bool = False, 268 progressive_buckets: bool = False,
255 batch_size: int = 1, 269 batch_size: int = 1,
256 num_class_images: int = 0, 270 num_class_images: int = 0,
@@ -274,7 +288,9 @@ class VlpnDataset(IterableDataset):
274 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( 288 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
275 [item.instance_image_path for item in items], 289 [item.instance_image_path for item in items],
276 base_size=size, 290 base_size=size,
291 step_size=bucket_step_size,
277 num_buckets=num_buckets, 292 num_buckets=num_buckets,
293 max_pixels=max_pixels_per_bucket,
278 progressive_buckets=progressive_buckets, 294 progressive_buckets=progressive_buckets,
279 ) 295 )
280 296