diff options
-rw-r--r-- | data/csv.py | 5 | ||||
-rw-r--r-- | train_dreambooth.py | 5 | ||||
-rw-r--r-- | train_ti.py | 5 |
3 files changed, 5 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py index 584a40c..ed8e93d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -245,6 +245,8 @@ class VlpnDataModule(): | |||
245 | 245 | ||
246 | val_dataset = VlpnDataset( | 246 | val_dataset = VlpnDataset( |
247 | self.data_val, self.prompt_processor, | 247 | self.data_val, self.prompt_processor, |
248 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
249 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
248 | repeat=self.valid_set_repeat, | 250 | repeat=self.valid_set_repeat, |
249 | batch_size=self.batch_size, generator=generator, | 251 | batch_size=self.batch_size, generator=generator, |
250 | size=self.size, interpolation=self.interpolation, | 252 | size=self.size, interpolation=self.interpolation, |
@@ -291,7 +293,7 @@ class VlpnDataset(IterableDataset): | |||
291 | self.generator = generator | 293 | self.generator = generator |
292 | 294 | ||
293 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 295 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
294 | [item.instance_image_path for item in items], | 296 | [item.instance_image_path for item in self.items], |
295 | base_size=size, | 297 | base_size=size, |
296 | step_size=bucket_step_size, | 298 | step_size=bucket_step_size, |
297 | num_buckets=num_buckets, | 299 | num_buckets=num_buckets, |
@@ -301,7 +303,6 @@ class VlpnDataset(IterableDataset): | |||
301 | 303 | ||
302 | self.bucket_item_range = torch.arange(len(self.bucket_items)) | 304 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
303 | 305 | ||
304 | self.cache = {} | ||
305 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 306 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
306 | 307 | ||
307 | def __len__(self): | 308 | def __len__(self): |
diff --git a/train_dreambooth.py b/train_dreambooth.py index aa5ff01..1a1f516 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -386,7 +386,7 @@ def parse_args(): | |||
386 | parser.add_argument( | 386 | parser.add_argument( |
387 | "--valid_set_repeat", | 387 | "--valid_set_repeat", |
388 | type=int, | 388 | type=int, |
389 | default=None, | 389 | default=1, |
390 | help="Times the images in the validation dataset are repeated." | 390 | help="Times the images in the validation dataset are repeated." |
391 | ) | 391 | ) |
392 | parser.add_argument( | 392 | parser.add_argument( |
@@ -457,9 +457,6 @@ def parse_args(): | |||
457 | if isinstance(args.exclude_collections, str): | 457 | if isinstance(args.exclude_collections, str): |
458 | args.exclude_collections = [args.exclude_collections] | 458 | args.exclude_collections = [args.exclude_collections] |
459 | 459 | ||
460 | if args.valid_set_repeat is None: | ||
461 | args.valid_set_repeat = args.train_batch_size | ||
462 | |||
463 | if args.output_dir is None: | 460 | if args.output_dir is None: |
464 | raise ValueError("You must specify --output_dir") | 461 | raise ValueError("You must specify --output_dir") |
465 | 462 | ||
diff --git a/train_ti.py b/train_ti.py index 7784d04..df8d443 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -383,7 +383,7 @@ def parse_args(): | |||
383 | parser.add_argument( | 383 | parser.add_argument( |
384 | "--valid_set_repeat", | 384 | "--valid_set_repeat", |
385 | type=int, | 385 | type=int, |
386 | default=None, | 386 | default=1, |
387 | help="Times the images in the validation dataset are repeated." | 387 | help="Times the images in the validation dataset are repeated." |
388 | ) | 388 | ) |
389 | parser.add_argument( | 389 | parser.add_argument( |
@@ -477,9 +477,6 @@ def parse_args(): | |||
477 | if isinstance(args.exclude_collections, str): | 477 | if isinstance(args.exclude_collections, str): |
478 | args.exclude_collections = [args.exclude_collections] | 478 | args.exclude_collections = [args.exclude_collections] |
479 | 479 | ||
480 | if args.valid_set_repeat is None: | ||
481 | args.valid_set_repeat = args.train_batch_size | ||
482 | |||
483 | if args.output_dir is None: | 480 | if args.output_dir is None: |
484 | raise ValueError("You must specify --output_dir") | 481 | raise ValueError("You must specify --output_dir") |
485 | 482 | ||