summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py6
-rw-r--r--train_dreambooth.py10
-rw-r--r--train_ti.py22
3 files changed, 37 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py
index 2f0a392..584a40c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -125,6 +125,7 @@ class VlpnDataModule():
125 interpolation: str = "bicubic", 125 interpolation: str = "bicubic",
126 template_key: str = "template", 126 template_key: str = "template",
127 valid_set_size: Optional[int] = None, 127 valid_set_size: Optional[int] = None,
128 valid_set_repeat: int = 1,
128 seed: Optional[int] = None, 129 seed: Optional[int] = None,
129 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 130 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
130 collate_fn=None, 131 collate_fn=None,
@@ -152,6 +153,7 @@ class VlpnDataModule():
152 self.template_key = template_key 153 self.template_key = template_key
153 self.interpolation = interpolation 154 self.interpolation = interpolation
154 self.valid_set_size = valid_set_size 155 self.valid_set_size = valid_set_size
156 self.valid_set_repeat = valid_set_repeat
155 self.seed = seed 157 self.seed = seed
156 self.filter = filter 158 self.filter = filter
157 self.collate_fn = collate_fn 159 self.collate_fn = collate_fn
@@ -243,6 +245,7 @@ class VlpnDataModule():
243 245
244 val_dataset = VlpnDataset( 246 val_dataset = VlpnDataset(
245 self.data_val, self.prompt_processor, 247 self.data_val, self.prompt_processor,
248 repeat=self.valid_set_repeat,
246 batch_size=self.batch_size, generator=generator, 249 batch_size=self.batch_size, generator=generator,
247 size=self.size, interpolation=self.interpolation, 250 size=self.size, interpolation=self.interpolation,
248 ) 251 )
@@ -267,6 +270,7 @@ class VlpnDataset(IterableDataset):
267 bucket_step_size: int = 64, 270 bucket_step_size: int = 64,
268 bucket_max_pixels: Optional[int] = None, 271 bucket_max_pixels: Optional[int] = None,
269 progressive_buckets: bool = False, 272 progressive_buckets: bool = False,
273 repeat: int = 1,
270 batch_size: int = 1, 274 batch_size: int = 1,
271 num_class_images: int = 0, 275 num_class_images: int = 0,
272 size: int = 768, 276 size: int = 768,
@@ -275,7 +279,7 @@ class VlpnDataset(IterableDataset):
275 interpolation: str = "bicubic", 279 interpolation: str = "bicubic",
276 generator: Optional[torch.Generator] = None, 280 generator: Optional[torch.Generator] = None,
277 ): 281 ):
278 self.items = items 282 self.items = items * repeat
279 self.batch_size = batch_size 283 self.batch_size = batch_size
280 284
281 self.prompt_processor = prompt_processor 285 self.prompt_processor = prompt_processor
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d396249..aa5ff01 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -384,6 +384,12 @@ def parse_args():
384 help="Number of images in the validation dataset." 384 help="Number of images in the validation dataset."
385 ) 385 )
386 parser.add_argument( 386 parser.add_argument(
387 "--valid_set_repeat",
388 type=int,
389 default=None,
390 help="Times the images in the validation dataset are repeated."
391 )
392 parser.add_argument(
387 "--train_batch_size", 393 "--train_batch_size",
388 type=int, 394 type=int,
389 default=1, 395 default=1,
@@ -451,6 +457,9 @@ def parse_args():
451 if isinstance(args.exclude_collections, str): 457 if isinstance(args.exclude_collections, str):
452 args.exclude_collections = [args.exclude_collections] 458 args.exclude_collections = [args.exclude_collections]
453 459
460 if args.valid_set_repeat is None:
461 args.valid_set_repeat = args.train_batch_size
462
454 if args.output_dir is None: 463 if args.output_dir is None:
455 raise ValueError("You must specify --output_dir") 464 raise ValueError("You must specify --output_dir")
456 465
@@ -764,6 +773,7 @@ def main():
764 dropout=args.tag_dropout, 773 dropout=args.tag_dropout,
765 template_key=args.train_data_template, 774 template_key=args.train_data_template,
766 valid_set_size=args.valid_set_size, 775 valid_set_size=args.valid_set_size,
776 valid_set_repeat=args.valid_set_repeat,
767 num_workers=args.dataloader_num_workers, 777 num_workers=args.dataloader_num_workers,
768 seed=args.seed, 778 seed=args.seed,
769 filter=keyword_filter, 779 filter=keyword_filter,
diff --git a/train_ti.py b/train_ti.py
index 03f52c4..7784d04 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -381,6 +381,12 @@ def parse_args():
381 help="Number of images in the validation dataset." 381 help="Number of images in the validation dataset."
382 ) 382 )
383 parser.add_argument( 383 parser.add_argument(
384 "--valid_set_repeat",
385 type=int,
386 default=None,
387 help="Times the images in the validation dataset are repeated."
388 )
389 parser.add_argument(
384 "--train_batch_size", 390 "--train_batch_size",
385 type=int, 391 type=int,
386 default=1, 392 default=1,
@@ -399,6 +405,12 @@ def parse_args():
399 help="The weight of prior preservation loss." 405 help="The weight of prior preservation loss."
400 ) 406 )
401 parser.add_argument( 407 parser.add_argument(
408 "--max_grad_norm",
409 default=3.0,
410 type=float,
411 help="Max gradient norm."
412 )
413 parser.add_argument(
402 "--noise_timesteps", 414 "--noise_timesteps",
403 type=int, 415 type=int,
404 default=1000, 416 default=1000,
@@ -465,6 +477,9 @@ def parse_args():
465 if isinstance(args.exclude_collections, str): 477 if isinstance(args.exclude_collections, str):
466 args.exclude_collections = [args.exclude_collections] 478 args.exclude_collections = [args.exclude_collections]
467 479
480 if args.valid_set_repeat is None:
481 args.valid_set_repeat = args.train_batch_size
482
468 if args.output_dir is None: 483 if args.output_dir is None:
469 raise ValueError("You must specify --output_dir") 484 raise ValueError("You must specify --output_dir")
470 485
@@ -735,6 +750,7 @@ def main():
735 dropout=args.tag_dropout, 750 dropout=args.tag_dropout,
736 template_key=args.train_data_template, 751 template_key=args.train_data_template,
737 valid_set_size=args.valid_set_size, 752 valid_set_size=args.valid_set_size,
753 valid_set_repeat=args.valid_set_repeat,
738 num_workers=args.dataloader_num_workers, 754 num_workers=args.dataloader_num_workers,
739 seed=args.seed, 755 seed=args.seed,
740 filter=keyword_filter, 756 filter=keyword_filter,
@@ -961,6 +977,12 @@ def main():
961 977
962 accelerator.backward(loss) 978 accelerator.backward(loss)
963 979
980 if accelerator.sync_gradients:
981 accelerator.clip_grad_norm_(
982 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
983 args.max_grad_norm
984 )
985
964 optimizer.step() 986 optimizer.step()
965 if not accelerator.optimizer_step_was_skipped: 987 if not accelerator.optimizer_step_was_skipped:
966 lr_scheduler.step() 988 lr_scheduler.step()