diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 46 |
1 files changed, 1 insertions, 45 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index a70c80e..5a4c47b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -74,26 +74,6 @@ def parse_args(): | |||
74 | help="The name of the current project.", | 74 | help="The name of the current project.", |
75 | ) | 75 | ) |
76 | parser.add_argument( | 76 | parser.add_argument( |
77 | "--placeholder_tokens", | ||
78 | type=str, | ||
79 | nargs='*', | ||
80 | default=[], | ||
81 | help="A token to use as a placeholder for the concept.", | ||
82 | ) | ||
83 | parser.add_argument( | ||
84 | "--initializer_tokens", | ||
85 | type=str, | ||
86 | nargs='*', | ||
87 | default=[], | ||
88 | help="A token to use as initializer word." | ||
89 | ) | ||
90 | parser.add_argument( | ||
91 | "--num_vectors", | ||
92 | type=int, | ||
93 | nargs='*', | ||
94 | help="Number of vectors per embedding." | ||
95 | ) | ||
96 | parser.add_argument( | ||
97 | "--exclude_collections", | 77 | "--exclude_collections", |
98 | type=str, | 78 | type=str, |
99 | nargs='*', | 79 | nargs='*', |
@@ -436,30 +416,6 @@ def parse_args(): | |||
436 | if args.project is None: | 416 | if args.project is None: |
437 | raise ValueError("You must specify --project") | 417 | raise ValueError("You must specify --project") |
438 | 418 | ||
439 | if isinstance(args.placeholder_tokens, str): | ||
440 | args.placeholder_tokens = [args.placeholder_tokens] | ||
441 | |||
442 | if isinstance(args.initializer_tokens, str): | ||
443 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | ||
444 | |||
445 | if len(args.initializer_tokens) == 0: | ||
446 | raise ValueError("You must specify --initializer_tokens") | ||
447 | |||
448 | if len(args.placeholder_tokens) == 0: | ||
449 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | ||
450 | |||
451 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | ||
452 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | ||
453 | |||
454 | if args.num_vectors is None: | ||
455 | args.num_vectors = 1 | ||
456 | |||
457 | if isinstance(args.num_vectors, int): | ||
458 | args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) | ||
459 | |||
460 | if len(args.placeholder_tokens) != len(args.num_vectors): | ||
461 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | ||
462 | |||
463 | if isinstance(args.collection, str): | 419 | if isinstance(args.collection, str): |
464 | args.collection = [args.collection] | 420 | args.collection = [args.collection] |
465 | 421 | ||
@@ -503,7 +459,7 @@ def main(): | |||
503 | 459 | ||
504 | vae.enable_slicing() | 460 | vae.enable_slicing() |
505 | vae.set_use_memory_efficient_attention_xformers(True) | 461 | vae.set_use_memory_efficient_attention_xformers(True) |
506 | unet.set_use_memory_efficient_attention_xformers(True) | 462 | unet.enable_xformers_memory_efficient_attention() |
507 | 463 | ||
508 | if args.gradient_checkpointing: | 464 | if args.gradient_checkpointing: |
509 | unet.enable_gradient_checkpointing() | 465 | unet.enable_gradient_checkpointing() |