summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-09 10:57:05 +0100
committerVolpeon <git@volpeon.ink>2023-01-09 10:57:05 +0100
commit7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb (patch)
tree3ad7621f9ebf3468ff46aee4b6a736fa52f7aaf4
parentAdd --valid_set_repeat (diff)
downloadtextual-inversion-diff-7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb.tar.gz
textual-inversion-diff-7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb.tar.bz2
textual-inversion-diff-7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb.zip
Enable buckets for validation, fixed vaildation repeat arg
-rw-r--r--data/csv.py5
-rw-r--r--train_dreambooth.py5
-rw-r--r--train_ti.py5
3 files changed, 5 insertions, 10 deletions
diff --git a/data/csv.py b/data/csv.py
index 584a40c..ed8e93d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -245,6 +245,8 @@ class VlpnDataModule():
245 245
246 val_dataset = VlpnDataset( 246 val_dataset = VlpnDataset(
247 self.data_val, self.prompt_processor, 247 self.data_val, self.prompt_processor,
248 num_buckets=self.num_buckets, progressive_buckets=True,
249 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
248 repeat=self.valid_set_repeat, 250 repeat=self.valid_set_repeat,
249 batch_size=self.batch_size, generator=generator, 251 batch_size=self.batch_size, generator=generator,
250 size=self.size, interpolation=self.interpolation, 252 size=self.size, interpolation=self.interpolation,
@@ -291,7 +293,7 @@ class VlpnDataset(IterableDataset):
291 self.generator = generator 293 self.generator = generator
292 294
293 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( 295 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
294 [item.instance_image_path for item in items], 296 [item.instance_image_path for item in self.items],
295 base_size=size, 297 base_size=size,
296 step_size=bucket_step_size, 298 step_size=bucket_step_size,
297 num_buckets=num_buckets, 299 num_buckets=num_buckets,
@@ -301,7 +303,6 @@ class VlpnDataset(IterableDataset):
301 303
302 self.bucket_item_range = torch.arange(len(self.bucket_items)) 304 self.bucket_item_range = torch.arange(len(self.bucket_items))
303 305
304 self.cache = {}
305 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() 306 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
306 307
307 def __len__(self): 308 def __len__(self):
diff --git a/train_dreambooth.py b/train_dreambooth.py
index aa5ff01..1a1f516 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -386,7 +386,7 @@ def parse_args():
386 parser.add_argument( 386 parser.add_argument(
387 "--valid_set_repeat", 387 "--valid_set_repeat",
388 type=int, 388 type=int,
389 default=None, 389 default=1,
390 help="Times the images in the validation dataset are repeated." 390 help="Times the images in the validation dataset are repeated."
391 ) 391 )
392 parser.add_argument( 392 parser.add_argument(
@@ -457,9 +457,6 @@ def parse_args():
457 if isinstance(args.exclude_collections, str): 457 if isinstance(args.exclude_collections, str):
458 args.exclude_collections = [args.exclude_collections] 458 args.exclude_collections = [args.exclude_collections]
459 459
460 if args.valid_set_repeat is None:
461 args.valid_set_repeat = args.train_batch_size
462
463 if args.output_dir is None: 460 if args.output_dir is None:
464 raise ValueError("You must specify --output_dir") 461 raise ValueError("You must specify --output_dir")
465 462
diff --git a/train_ti.py b/train_ti.py
index 7784d04..df8d443 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -383,7 +383,7 @@ def parse_args():
383 parser.add_argument( 383 parser.add_argument(
384 "--valid_set_repeat", 384 "--valid_set_repeat",
385 type=int, 385 type=int,
386 default=None, 386 default=1,
387 help="Times the images in the validation dataset are repeated." 387 help="Times the images in the validation dataset are repeated."
388 ) 388 )
389 parser.add_argument( 389 parser.add_argument(
@@ -477,9 +477,6 @@ def parse_args():
477 if isinstance(args.exclude_collections, str): 477 if isinstance(args.exclude_collections, str):
478 args.exclude_collections = [args.exclude_collections] 478 args.exclude_collections = [args.exclude_collections]
479 479
480 if args.valid_set_repeat is None:
481 args.valid_set_repeat = args.train_batch_size
482
483 if args.output_dir is None: 480 if args.output_dir is None:
484 raise ValueError("You must specify --output_dir") 481 raise ValueError("You must specify --output_dir")
485 482