summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py46
1 files changed, 1 insertions, 45 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a70c80e..5a4c47b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -74,26 +74,6 @@ def parse_args():
74 help="The name of the current project.", 74 help="The name of the current project.",
75 ) 75 )
76 parser.add_argument( 76 parser.add_argument(
77 "--placeholder_tokens",
78 type=str,
79 nargs='*',
80 default=[],
81 help="A token to use as a placeholder for the concept.",
82 )
83 parser.add_argument(
84 "--initializer_tokens",
85 type=str,
86 nargs='*',
87 default=[],
88 help="A token to use as initializer word."
89 )
90 parser.add_argument(
91 "--num_vectors",
92 type=int,
93 nargs='*',
94 help="Number of vectors per embedding."
95 )
96 parser.add_argument(
97 "--exclude_collections", 77 "--exclude_collections",
98 type=str, 78 type=str,
99 nargs='*', 79 nargs='*',
@@ -436,30 +416,6 @@ def parse_args():
436 if args.project is None: 416 if args.project is None:
437 raise ValueError("You must specify --project") 417 raise ValueError("You must specify --project")
438 418
439 if isinstance(args.placeholder_tokens, str):
440 args.placeholder_tokens = [args.placeholder_tokens]
441
442 if isinstance(args.initializer_tokens, str):
443 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
444
445 if len(args.initializer_tokens) == 0:
446 raise ValueError("You must specify --initializer_tokens")
447
448 if len(args.placeholder_tokens) == 0:
449 args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
450
451 if len(args.placeholder_tokens) != len(args.initializer_tokens):
452 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
453
454 if args.num_vectors is None:
455 args.num_vectors = 1
456
457 if isinstance(args.num_vectors, int):
458 args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
459
460 if len(args.placeholder_tokens) != len(args.num_vectors):
461 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
462
463 if isinstance(args.collection, str): 419 if isinstance(args.collection, str):
464 args.collection = [args.collection] 420 args.collection = [args.collection]
465 421
@@ -503,7 +459,7 @@ def main():
503 459
504 vae.enable_slicing() 460 vae.enable_slicing()
505 vae.set_use_memory_efficient_attention_xformers(True) 461 vae.set_use_memory_efficient_attention_xformers(True)
506 unet.set_use_memory_efficient_attention_xformers(True) 462 unet.enable_xformers_memory_efficient_attention()
507 463
508 if args.gradient_checkpointing: 464 if args.gradient_checkpointing:
509 unet.enable_gradient_checkpointing() 465 unet.enable_gradient_checkpointing()