From 7ce728b7ea9cfe6b6dc7d05826c1bf64eec5aacb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 9 Jan 2023 10:57:05 +0100 Subject: Enable buckets for validation, fixed vaildation repeat arg --- data/csv.py | 5 +++-- train_dreambooth.py | 5 +---- train_ti.py | 5 +---- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/data/csv.py b/data/csv.py index 584a40c..ed8e93d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -245,6 +245,8 @@ class VlpnDataModule(): val_dataset = VlpnDataset( self.data_val, self.prompt_processor, + num_buckets=self.num_buckets, progressive_buckets=True, + bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, repeat=self.valid_set_repeat, batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, @@ -291,7 +293,7 @@ class VlpnDataset(IterableDataset): self.generator = generator self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( - [item.instance_image_path for item in items], + [item.instance_image_path for item in self.items], base_size=size, step_size=bucket_step_size, num_buckets=num_buckets, @@ -301,7 +303,6 @@ class VlpnDataset(IterableDataset): self.bucket_item_range = torch.arange(len(self.bucket_items)) - self.cache = {} self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() def __len__(self): diff --git a/train_dreambooth.py b/train_dreambooth.py index aa5ff01..1a1f516 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -386,7 +386,7 @@ def parse_args(): parser.add_argument( "--valid_set_repeat", type=int, - default=None, + default=1, help="Times the images in the validation dataset are repeated." ) parser.add_argument( @@ -457,9 +457,6 @@ def parse_args(): if isinstance(args.exclude_collections, str): args.exclude_collections = [args.exclude_collections] - if args.valid_set_repeat is None: - args.valid_set_repeat = args.train_batch_size - if args.output_dir is None: raise ValueError("You must specify --output_dir") diff --git a/train_ti.py b/train_ti.py index 7784d04..df8d443 100644 --- a/train_ti.py +++ b/train_ti.py @@ -383,7 +383,7 @@ def parse_args(): parser.add_argument( "--valid_set_repeat", type=int, - default=None, + default=1, help="Times the images in the validation dataset are repeated." ) parser.add_argument( @@ -477,9 +477,6 @@ def parse_args(): if isinstance(args.exclude_collections, str): args.exclude_collections = [args.exclude_collections] - if args.valid_set_repeat is None: - args.valid_set_repeat = args.train_batch_size - if args.output_dir is None: raise ValueError("You must specify --output_dir") -- cgit v1.2.3-70-g09d2