diff options
author | Volpeon <git@volpeon.ink> | 2023-01-08 20:36:17 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-08 20:36:17 +0100 |
commit | ed892a06ba7a231a84d47bd835fc625aa3f2c75c (patch) | |
tree | 4dc33b2b0bebcf6b7880a42e0518e4ae5eecd3ad | |
parent | Improved aspect ratio bucketing (diff) | |
download | textual-inversion-diff-ed892a06ba7a231a84d47bd835fc625aa3f2c75c.tar.gz textual-inversion-diff-ed892a06ba7a231a84d47bd835fc625aa3f2c75c.tar.bz2 textual-inversion-diff-ed892a06ba7a231a84d47bd835fc625aa3f2c75c.zip |
Fix
-rw-r--r-- | data/csv.py | 10 | ||||
-rw-r--r-- | train_ti.py | 2 |
2 files changed, 6 insertions, 6 deletions
diff --git a/data/csv.py b/data/csv.py index 55a1988..d9f9db8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -118,7 +118,7 @@ class VlpnDataModule(): | |||
118 | size: int = 768, | 118 | size: int = 768, |
119 | num_buckets: int = 0, | 119 | num_buckets: int = 0, |
120 | bucket_step_size: int = 64, | 120 | bucket_step_size: int = 64, |
121 | max_pixels_per_bucket: Optional[int] = None, | 121 | bucket_max_pixels: Optional[int] = None, |
122 | progressive_buckets: bool = False, | 122 | progressive_buckets: bool = False, |
123 | dropout: float = 0, | 123 | dropout: float = 0, |
124 | interpolation: str = "bicubic", | 124 | interpolation: str = "bicubic", |
@@ -145,7 +145,7 @@ class VlpnDataModule(): | |||
145 | self.size = size | 145 | self.size = size |
146 | self.num_buckets = num_buckets | 146 | self.num_buckets = num_buckets |
147 | self.bucket_step_size = bucket_step_size | 147 | self.bucket_step_size = bucket_step_size |
148 | self.max_pixels_per_bucket = max_pixels_per_bucket | 148 | self.bucket_max_pixels = bucket_max_pixels |
149 | self.progressive_buckets = progressive_buckets | 149 | self.progressive_buckets = progressive_buckets |
150 | self.dropout = dropout | 150 | self.dropout = dropout |
151 | self.template_key = template_key | 151 | self.template_key = template_key |
@@ -234,7 +234,7 @@ class VlpnDataModule(): | |||
234 | train_dataset = VlpnDataset( | 234 | train_dataset = VlpnDataset( |
235 | self.data_train, self.prompt_processor, | 235 | self.data_train, self.prompt_processor, |
236 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 236 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
237 | bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket, | 237 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
238 | batch_size=self.batch_size, generator=generator, | 238 | batch_size=self.batch_size, generator=generator, |
239 | size=self.size, interpolation=self.interpolation, | 239 | size=self.size, interpolation=self.interpolation, |
240 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 240 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, |
@@ -264,7 +264,7 @@ class VlpnDataset(IterableDataset): | |||
264 | prompt_processor: PromptProcessor, | 264 | prompt_processor: PromptProcessor, |
265 | num_buckets: int = 1, | 265 | num_buckets: int = 1, |
266 | bucket_step_size: int = 64, | 266 | bucket_step_size: int = 64, |
267 | max_pixels_per_bucket: Optional[int] = None, | 267 | bucket_max_pixels: Optional[int] = None, |
268 | progressive_buckets: bool = False, | 268 | progressive_buckets: bool = False, |
269 | batch_size: int = 1, | 269 | batch_size: int = 1, |
270 | num_class_images: int = 0, | 270 | num_class_images: int = 0, |
@@ -290,7 +290,7 @@ class VlpnDataset(IterableDataset): | |||
290 | base_size=size, | 290 | base_size=size, |
291 | step_size=bucket_step_size, | 291 | step_size=bucket_step_size, |
292 | num_buckets=num_buckets, | 292 | num_buckets=num_buckets, |
293 | max_pixels=max_pixels_per_bucket, | 293 | max_pixels=bucket_max_pixels, |
294 | progressive_buckets=progressive_buckets, | 294 | progressive_buckets=progressive_buckets, |
295 | ) | 295 | ) |
296 | 296 | ||
diff --git a/train_ti.py b/train_ti.py index eb0b8b6..03f52c4 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -142,7 +142,7 @@ def parse_args(): | |||
142 | parser.add_argument( | 142 | parser.add_argument( |
143 | "--num_buckets", | 143 | "--num_buckets", |
144 | type=int, | 144 | type=int, |
145 | default=4, | 145 | default=0, |
146 | help="Number of aspect ratio buckets in either direction.", | 146 | help="Number of aspect ratio buckets in either direction.", |
147 | ) | 147 | ) |
148 | parser.add_argument( | 148 | parser.add_argument( |