diff options
author | Volpeon <git@volpeon.ink> | 2023-01-08 14:42:40 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-08 14:42:40 +0100 |
commit | c16d5ba5ab1445386efcb87325557d4c784891e6 (patch) | |
tree | 85f521bc348ef50c13a706a6b1a6a78a06235722 | |
parent | Fixed aspect ratio bucketing; allow passing token IDs to pipeline (diff) | |
download | textual-inversion-diff-c16d5ba5ab1445386efcb87325557d4c784891e6.tar.gz textual-inversion-diff-c16d5ba5ab1445386efcb87325557d4c784891e6.tar.bz2 textual-inversion-diff-c16d5ba5ab1445386efcb87325557d4c784891e6.zip |
Cleanup
-rw-r--r-- | data/csv.py | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/data/csv.py b/data/csv.py index 289a64d..eaef5e6 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -40,15 +40,22 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]): | |||
40 | return {"content": prompt} if isinstance(prompt, str) else prompt | 40 | return {"content": prompt} if isinstance(prompt, str) else prompt |
41 | 41 | ||
42 | 42 | ||
43 | def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): | 43 | def generate_buckets( |
44 | items: list[str], | ||
45 | base_size: int, | ||
46 | step_size: int = 64, | ||
47 | num_buckets: int = 4, | ||
48 | progressive_buckets: bool = False, | ||
49 | return_tensor: bool = True | ||
50 | ): | ||
44 | bucket_items: list[int] = [] | 51 | bucket_items: list[int] = [] |
45 | bucket_assignments: list[int] = [] | 52 | bucket_assignments: list[int] = [] |
46 | buckets = [1.0] | 53 | buckets = [1.0] |
47 | 54 | ||
48 | for i in range(1, num_buckets + 1): | 55 | for i in range(1, num_buckets + 1): |
49 | s = size + i * 64 | 56 | s = base_size + i * step_size |
50 | buckets.append(s / size) | 57 | buckets.append(s / base_size) |
51 | buckets.append(size / s) | 58 | buckets.append(base_size / s) |
52 | 59 | ||
53 | buckets = torch.tensor(buckets) | 60 | buckets = torch.tensor(buckets) |
54 | bucket_indices = torch.arange(len(buckets)) | 61 | bucket_indices = torch.arange(len(buckets)) |
@@ -58,9 +65,9 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_ | |||
58 | ratio = image.width / image.height | 65 | ratio = image.width / image.height |
59 | 66 | ||
60 | if ratio >= 1: | 67 | if ratio >= 1: |
61 | mask = torch.bitwise_and(buckets >= 1, buckets <= ratio) | 68 | mask = torch.logical_and(buckets >= 1, buckets <= ratio) |
62 | else: | 69 | else: |
63 | mask = torch.bitwise_and(buckets <= 1, buckets >= ratio) | 70 | mask = torch.logical_and(buckets <= 1, buckets >= ratio) |
64 | 71 | ||
65 | if not progressive_buckets: | 72 | if not progressive_buckets: |
66 | mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() | 73 | mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() |
@@ -73,7 +80,13 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_ | |||
73 | bucket_items += [i] * len(indices) | 80 | bucket_items += [i] * len(indices) |
74 | bucket_assignments += indices | 81 | bucket_assignments += indices |
75 | 82 | ||
76 | return buckets.tolist(), bucket_items, bucket_assignments | 83 | if return_tensor: |
84 | bucket_items = torch.tensor(bucket_items) | ||
85 | bucket_assignments = torch.tensor(bucket_assignments) | ||
86 | else: | ||
87 | buckets = buckets.tolist() | ||
88 | |||
89 | return buckets, bucket_items, bucket_assignments | ||
77 | 90 | ||
78 | 91 | ||
79 | class VlpnDataItem(NamedTuple): | 92 | class VlpnDataItem(NamedTuple): |
@@ -256,17 +269,14 @@ class VlpnDataset(IterableDataset): | |||
256 | self.interpolation = interpolations[interpolation] | 269 | self.interpolation = interpolations[interpolation] |
257 | self.generator = generator | 270 | self.generator = generator |
258 | 271 | ||
259 | buckets, bucket_items, bucket_assignments = generate_buckets( | 272 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
260 | [item.instance_image_path for item in items], | 273 | [item.instance_image_path for item in items], |
261 | size, | 274 | base_size=size, |
262 | num_buckets, | 275 | num_buckets=num_buckets, |
263 | progressive_buckets | 276 | progressive_buckets=progressive_buckets, |
264 | ) | 277 | ) |
265 | 278 | ||
266 | self.buckets = torch.tensor(buckets) | 279 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
267 | self.bucket_items = torch.tensor(bucket_items) | ||
268 | self.bucket_assignments = torch.tensor(bucket_assignments) | ||
269 | self.bucket_item_range = torch.arange(len(bucket_items)) | ||
270 | 280 | ||
271 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 281 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
272 | 282 | ||