diff options
author | Volpeon <git@volpeon.ink> | 2023-01-09 10:19:37 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-09 10:19:37 +0100 |
commit | b57ca669a150d9313447612fb8c37668f4f2a80d (patch) | |
tree | b0ebfedc33c26847838850416b96fd2623cf6ba5 | |
parent | No cache after all (diff) | |
download | textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.tar.gz textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.tar.bz2 textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.zip |
Add --valid_set_repeat
-rw-r--r-- | data/csv.py | 6 | ||||
-rw-r--r-- | train_dreambooth.py | 10 | ||||
-rw-r--r-- | train_ti.py | 22 |
3 files changed, 37 insertions, 1 deletion
diff --git a/data/csv.py b/data/csv.py index 2f0a392..584a40c 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -125,6 +125,7 @@ class VlpnDataModule(): | |||
125 | interpolation: str = "bicubic", | 125 | interpolation: str = "bicubic", |
126 | template_key: str = "template", | 126 | template_key: str = "template", |
127 | valid_set_size: Optional[int] = None, | 127 | valid_set_size: Optional[int] = None, |
128 | valid_set_repeat: int = 1, | ||
128 | seed: Optional[int] = None, | 129 | seed: Optional[int] = None, |
129 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 130 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
130 | collate_fn=None, | 131 | collate_fn=None, |
@@ -152,6 +153,7 @@ class VlpnDataModule(): | |||
152 | self.template_key = template_key | 153 | self.template_key = template_key |
153 | self.interpolation = interpolation | 154 | self.interpolation = interpolation |
154 | self.valid_set_size = valid_set_size | 155 | self.valid_set_size = valid_set_size |
156 | self.valid_set_repeat = valid_set_repeat | ||
155 | self.seed = seed | 157 | self.seed = seed |
156 | self.filter = filter | 158 | self.filter = filter |
157 | self.collate_fn = collate_fn | 159 | self.collate_fn = collate_fn |
@@ -243,6 +245,7 @@ class VlpnDataModule(): | |||
243 | 245 | ||
244 | val_dataset = VlpnDataset( | 246 | val_dataset = VlpnDataset( |
245 | self.data_val, self.prompt_processor, | 247 | self.data_val, self.prompt_processor, |
248 | repeat=self.valid_set_repeat, | ||
246 | batch_size=self.batch_size, generator=generator, | 249 | batch_size=self.batch_size, generator=generator, |
247 | size=self.size, interpolation=self.interpolation, | 250 | size=self.size, interpolation=self.interpolation, |
248 | ) | 251 | ) |
@@ -267,6 +270,7 @@ class VlpnDataset(IterableDataset): | |||
267 | bucket_step_size: int = 64, | 270 | bucket_step_size: int = 64, |
268 | bucket_max_pixels: Optional[int] = None, | 271 | bucket_max_pixels: Optional[int] = None, |
269 | progressive_buckets: bool = False, | 272 | progressive_buckets: bool = False, |
273 | repeat: int = 1, | ||
270 | batch_size: int = 1, | 274 | batch_size: int = 1, |
271 | num_class_images: int = 0, | 275 | num_class_images: int = 0, |
272 | size: int = 768, | 276 | size: int = 768, |
@@ -275,7 +279,7 @@ class VlpnDataset(IterableDataset): | |||
275 | interpolation: str = "bicubic", | 279 | interpolation: str = "bicubic", |
276 | generator: Optional[torch.Generator] = None, | 280 | generator: Optional[torch.Generator] = None, |
277 | ): | 281 | ): |
278 | self.items = items | 282 | self.items = items * repeat |
279 | self.batch_size = batch_size | 283 | self.batch_size = batch_size |
280 | 284 | ||
281 | self.prompt_processor = prompt_processor | 285 | self.prompt_processor = prompt_processor |
diff --git a/train_dreambooth.py b/train_dreambooth.py index d396249..aa5ff01 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -384,6 +384,12 @@ def parse_args(): | |||
384 | help="Number of images in the validation dataset." | 384 | help="Number of images in the validation dataset." |
385 | ) | 385 | ) |
386 | parser.add_argument( | 386 | parser.add_argument( |
387 | "--valid_set_repeat", | ||
388 | type=int, | ||
389 | default=None, | ||
390 | help="Times the images in the validation dataset are repeated." | ||
391 | ) | ||
392 | parser.add_argument( | ||
387 | "--train_batch_size", | 393 | "--train_batch_size", |
388 | type=int, | 394 | type=int, |
389 | default=1, | 395 | default=1, |
@@ -451,6 +457,9 @@ def parse_args(): | |||
451 | if isinstance(args.exclude_collections, str): | 457 | if isinstance(args.exclude_collections, str): |
452 | args.exclude_collections = [args.exclude_collections] | 458 | args.exclude_collections = [args.exclude_collections] |
453 | 459 | ||
460 | if args.valid_set_repeat is None: | ||
461 | args.valid_set_repeat = args.train_batch_size | ||
462 | |||
454 | if args.output_dir is None: | 463 | if args.output_dir is None: |
455 | raise ValueError("You must specify --output_dir") | 464 | raise ValueError("You must specify --output_dir") |
456 | 465 | ||
@@ -764,6 +773,7 @@ def main(): | |||
764 | dropout=args.tag_dropout, | 773 | dropout=args.tag_dropout, |
765 | template_key=args.train_data_template, | 774 | template_key=args.train_data_template, |
766 | valid_set_size=args.valid_set_size, | 775 | valid_set_size=args.valid_set_size, |
776 | valid_set_repeat=args.valid_set_repeat, | ||
767 | num_workers=args.dataloader_num_workers, | 777 | num_workers=args.dataloader_num_workers, |
768 | seed=args.seed, | 778 | seed=args.seed, |
769 | filter=keyword_filter, | 779 | filter=keyword_filter, |
diff --git a/train_ti.py b/train_ti.py index 03f52c4..7784d04 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -381,6 +381,12 @@ def parse_args(): | |||
381 | help="Number of images in the validation dataset." | 381 | help="Number of images in the validation dataset." |
382 | ) | 382 | ) |
383 | parser.add_argument( | 383 | parser.add_argument( |
384 | "--valid_set_repeat", | ||
385 | type=int, | ||
386 | default=None, | ||
387 | help="Times the images in the validation dataset are repeated." | ||
388 | ) | ||
389 | parser.add_argument( | ||
384 | "--train_batch_size", | 390 | "--train_batch_size", |
385 | type=int, | 391 | type=int, |
386 | default=1, | 392 | default=1, |
@@ -399,6 +405,12 @@ def parse_args(): | |||
399 | help="The weight of prior preservation loss." | 405 | help="The weight of prior preservation loss." |
400 | ) | 406 | ) |
401 | parser.add_argument( | 407 | parser.add_argument( |
408 | "--max_grad_norm", | ||
409 | default=3.0, | ||
410 | type=float, | ||
411 | help="Max gradient norm." | ||
412 | ) | ||
413 | parser.add_argument( | ||
402 | "--noise_timesteps", | 414 | "--noise_timesteps", |
403 | type=int, | 415 | type=int, |
404 | default=1000, | 416 | default=1000, |
@@ -465,6 +477,9 @@ def parse_args(): | |||
465 | if isinstance(args.exclude_collections, str): | 477 | if isinstance(args.exclude_collections, str): |
466 | args.exclude_collections = [args.exclude_collections] | 478 | args.exclude_collections = [args.exclude_collections] |
467 | 479 | ||
480 | if args.valid_set_repeat is None: | ||
481 | args.valid_set_repeat = args.train_batch_size | ||
482 | |||
468 | if args.output_dir is None: | 483 | if args.output_dir is None: |
469 | raise ValueError("You must specify --output_dir") | 484 | raise ValueError("You must specify --output_dir") |
470 | 485 | ||
@@ -735,6 +750,7 @@ def main(): | |||
735 | dropout=args.tag_dropout, | 750 | dropout=args.tag_dropout, |
736 | template_key=args.train_data_template, | 751 | template_key=args.train_data_template, |
737 | valid_set_size=args.valid_set_size, | 752 | valid_set_size=args.valid_set_size, |
753 | valid_set_repeat=args.valid_set_repeat, | ||
738 | num_workers=args.dataloader_num_workers, | 754 | num_workers=args.dataloader_num_workers, |
739 | seed=args.seed, | 755 | seed=args.seed, |
740 | filter=keyword_filter, | 756 | filter=keyword_filter, |
@@ -961,6 +977,12 @@ def main(): | |||
961 | 977 | ||
962 | accelerator.backward(loss) | 978 | accelerator.backward(loss) |
963 | 979 | ||
980 | if accelerator.sync_gradients: | ||
981 | accelerator.clip_grad_norm_( | ||
982 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
983 | args.max_grad_norm | ||
984 | ) | ||
985 | |||
964 | optimizer.step() | 986 | optimizer.step() |
965 | if not accelerator.optimizer_step_was_skipped: | 987 | if not accelerator.optimizer_step_was_skipped: |
966 | lr_scheduler.step() | 988 | lr_scheduler.step() |