From b57ca669a150d9313447612fb8c37668f4f2a80d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 9 Jan 2023 10:19:37 +0100
Subject: Add --valid_set_repeat

---
 data/csv.py         |  6 +++++-
 train_dreambooth.py | 10 ++++++++++
 train_ti.py         | 22 ++++++++++++++++++++++
 3 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/data/csv.py b/data/csv.py
index 2f0a392..584a40c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -125,6 +125,7 @@ class VlpnDataModule():
         interpolation: str = "bicubic",
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
+        valid_set_repeat: int = 1,
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         collate_fn=None,
@@ -152,6 +153,7 @@ class VlpnDataModule():
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
+        self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
         self.collate_fn = collate_fn
@@ -243,6 +245,7 @@ class VlpnDataModule():
 
         val_dataset = VlpnDataset(
             self.data_val, self.prompt_processor,
+            repeat=self.valid_set_repeat,
             batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
         )
@@ -267,6 +270,7 @@ class VlpnDataset(IterableDataset):
         bucket_step_size: int = 64,
         bucket_max_pixels: Optional[int] = None,
         progressive_buckets: bool = False,
+        repeat: int = 1,
         batch_size: int = 1,
         num_class_images: int = 0,
         size: int = 768,
@@ -275,7 +279,7 @@ class VlpnDataset(IterableDataset):
         interpolation: str = "bicubic",
         generator: Optional[torch.Generator] = None,
     ):
-        self.items = items
+        self.items = items * repeat
         self.batch_size = batch_size
 
         self.prompt_processor = prompt_processor
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d396249..aa5ff01 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -383,6 +383,12 @@ def parse_args():
         default=None,
         help="Number of images in the validation dataset."
     )
+    parser.add_argument(
+        "--valid_set_repeat",
+        type=int,
+        default=None,
+        help="Times the images in the validation dataset are repeated."
+    )
     parser.add_argument(
         "--train_batch_size",
         type=int,
@@ -451,6 +457,9 @@ def parse_args():
     if isinstance(args.exclude_collections, str):
         args.exclude_collections = [args.exclude_collections]
 
+    if args.valid_set_repeat is None:
+        args.valid_set_repeat = args.train_batch_size
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -764,6 +773,7 @@ def main():
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
+        valid_set_repeat=args.valid_set_repeat,
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
diff --git a/train_ti.py b/train_ti.py
index 03f52c4..7784d04 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -380,6 +380,12 @@ def parse_args():
         default=None,
         help="Number of images in the validation dataset."
     )
+    parser.add_argument(
+        "--valid_set_repeat",
+        type=int,
+        default=None,
+        help="Times the images in the validation dataset are repeated."
+    )
     parser.add_argument(
         "--train_batch_size",
         type=int,
@@ -398,6 +404,12 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--max_grad_norm",
+        default=3.0,
+        type=float,
+        help="Max gradient norm."
+    )
     parser.add_argument(
         "--noise_timesteps",
         type=int,
@@ -465,6 +477,9 @@ def parse_args():
     if isinstance(args.exclude_collections, str):
         args.exclude_collections = [args.exclude_collections]
 
+    if args.valid_set_repeat is None:
+        args.valid_set_repeat = args.train_batch_size
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -735,6 +750,7 @@ def main():
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
+        valid_set_repeat=args.valid_set_repeat,
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
@@ -961,6 +977,12 @@ def main():
 
                 accelerator.backward(loss)
 
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(
+                        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+                        args.max_grad_norm
+                    )
+
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
-- 
cgit v1.2.3-54-g00ecf
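
Usage sketch: the new flag is optional in both scripts; when left unset,
parse_args() falls back to --train_batch_size, so each validation image is
duplicated often enough to fill a whole batch (VlpnDataset implements this
as plain Python list repetition, `items * repeat`). The invocation below is
hypothetical: the flags shown are the ones this patch touches, but all
values are placeholders, and any other arguments your setup requires (data
paths, model name, etc.) are omitted.

    # Hypothetical run of train_ti.py. --valid_set_repeat could be omitted
    # here, since it would then default to --train_batch_size (4).
    python train_ti.py \
        --output_dir output/example \
        --train_batch_size 4 \
        --valid_set_size 8 \
        --valid_set_repeat 4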