diff options
| -rw-r--r-- | data/csv.py | 5 | ||||
| -rw-r--r-- | train_dreambooth.py | 5 | ||||
| -rw-r--r-- | train_ti.py | 5 |
3 files changed, 5 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py index 584a40c..ed8e93d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -245,6 +245,8 @@ class VlpnDataModule(): | |||
| 245 | 245 | ||
| 246 | val_dataset = VlpnDataset( | 246 | val_dataset = VlpnDataset( |
| 247 | self.data_val, self.prompt_processor, | 247 | self.data_val, self.prompt_processor, |
| 248 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
| 249 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
| 248 | repeat=self.valid_set_repeat, | 250 | repeat=self.valid_set_repeat, |
| 249 | batch_size=self.batch_size, generator=generator, | 251 | batch_size=self.batch_size, generator=generator, |
| 250 | size=self.size, interpolation=self.interpolation, | 252 | size=self.size, interpolation=self.interpolation, |
| @@ -291,7 +293,7 @@ class VlpnDataset(IterableDataset): | |||
| 291 | self.generator = generator | 293 | self.generator = generator |
| 292 | 294 | ||
| 293 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 295 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
| 294 | [item.instance_image_path for item in items], | 296 | [item.instance_image_path for item in self.items], |
| 295 | base_size=size, | 297 | base_size=size, |
| 296 | step_size=bucket_step_size, | 298 | step_size=bucket_step_size, |
| 297 | num_buckets=num_buckets, | 299 | num_buckets=num_buckets, |
| @@ -301,7 +303,6 @@ class VlpnDataset(IterableDataset): | |||
| 301 | 303 | ||
| 302 | self.bucket_item_range = torch.arange(len(self.bucket_items)) | 304 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
| 303 | 305 | ||
| 304 | self.cache = {} | ||
| 305 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 306 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() |
| 306 | 307 | ||
| 307 | def __len__(self): | 308 | def __len__(self): |
diff --git a/train_dreambooth.py b/train_dreambooth.py index aa5ff01..1a1f516 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -386,7 +386,7 @@ def parse_args(): | |||
| 386 | parser.add_argument( | 386 | parser.add_argument( |
| 387 | "--valid_set_repeat", | 387 | "--valid_set_repeat", |
| 388 | type=int, | 388 | type=int, |
| 389 | default=None, | 389 | default=1, |
| 390 | help="Times the images in the validation dataset are repeated." | 390 | help="Times the images in the validation dataset are repeated." |
| 391 | ) | 391 | ) |
| 392 | parser.add_argument( | 392 | parser.add_argument( |
| @@ -457,9 +457,6 @@ def parse_args(): | |||
| 457 | if isinstance(args.exclude_collections, str): | 457 | if isinstance(args.exclude_collections, str): |
| 458 | args.exclude_collections = [args.exclude_collections] | 458 | args.exclude_collections = [args.exclude_collections] |
| 459 | 459 | ||
| 460 | if args.valid_set_repeat is None: | ||
| 461 | args.valid_set_repeat = args.train_batch_size | ||
| 462 | |||
| 463 | if args.output_dir is None: | 460 | if args.output_dir is None: |
| 464 | raise ValueError("You must specify --output_dir") | 461 | raise ValueError("You must specify --output_dir") |
| 465 | 462 | ||
diff --git a/train_ti.py b/train_ti.py index 7784d04..df8d443 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -383,7 +383,7 @@ def parse_args(): | |||
| 383 | parser.add_argument( | 383 | parser.add_argument( |
| 384 | "--valid_set_repeat", | 384 | "--valid_set_repeat", |
| 385 | type=int, | 385 | type=int, |
| 386 | default=None, | 386 | default=1, |
| 387 | help="Times the images in the validation dataset are repeated." | 387 | help="Times the images in the validation dataset are repeated." |
| 388 | ) | 388 | ) |
| 389 | parser.add_argument( | 389 | parser.add_argument( |
| @@ -477,9 +477,6 @@ def parse_args(): | |||
| 477 | if isinstance(args.exclude_collections, str): | 477 | if isinstance(args.exclude_collections, str): |
| 478 | args.exclude_collections = [args.exclude_collections] | 478 | args.exclude_collections = [args.exclude_collections] |
| 479 | 479 | ||
| 480 | if args.valid_set_repeat is None: | ||
| 481 | args.valid_set_repeat = args.train_batch_size | ||
| 482 | |||
| 483 | if args.output_dir is None: | 480 | if args.output_dir is None: |
| 484 | raise ValueError("You must specify --output_dir") | 481 | raise ValueError("You must specify --output_dir") |
| 485 | 482 | ||
