diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 46 |
1 files changed, 1 insertions, 45 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index a70c80e..5a4c47b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -74,26 +74,6 @@ def parse_args(): | |||
| 74 | help="The name of the current project.", | 74 | help="The name of the current project.", |
| 75 | ) | 75 | ) |
| 76 | parser.add_argument( | 76 | parser.add_argument( |
| 77 | "--placeholder_tokens", | ||
| 78 | type=str, | ||
| 79 | nargs='*', | ||
| 80 | default=[], | ||
| 81 | help="A token to use as a placeholder for the concept.", | ||
| 82 | ) | ||
| 83 | parser.add_argument( | ||
| 84 | "--initializer_tokens", | ||
| 85 | type=str, | ||
| 86 | nargs='*', | ||
| 87 | default=[], | ||
| 88 | help="A token to use as initializer word." | ||
| 89 | ) | ||
| 90 | parser.add_argument( | ||
| 91 | "--num_vectors", | ||
| 92 | type=int, | ||
| 93 | nargs='*', | ||
| 94 | help="Number of vectors per embedding." | ||
| 95 | ) | ||
| 96 | parser.add_argument( | ||
| 97 | "--exclude_collections", | 77 | "--exclude_collections", |
| 98 | type=str, | 78 | type=str, |
| 99 | nargs='*', | 79 | nargs='*', |
| @@ -436,30 +416,6 @@ def parse_args(): | |||
| 436 | if args.project is None: | 416 | if args.project is None: |
| 437 | raise ValueError("You must specify --project") | 417 | raise ValueError("You must specify --project") |
| 438 | 418 | ||
| 439 | if isinstance(args.placeholder_tokens, str): | ||
| 440 | args.placeholder_tokens = [args.placeholder_tokens] | ||
| 441 | |||
| 442 | if isinstance(args.initializer_tokens, str): | ||
| 443 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | ||
| 444 | |||
| 445 | if len(args.initializer_tokens) == 0: | ||
| 446 | raise ValueError("You must specify --initializer_tokens") | ||
| 447 | |||
| 448 | if len(args.placeholder_tokens) == 0: | ||
| 449 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | ||
| 450 | |||
| 451 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | ||
| 452 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | ||
| 453 | |||
| 454 | if args.num_vectors is None: | ||
| 455 | args.num_vectors = 1 | ||
| 456 | |||
| 457 | if isinstance(args.num_vectors, int): | ||
| 458 | args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) | ||
| 459 | |||
| 460 | if len(args.placeholder_tokens) != len(args.num_vectors): | ||
| 461 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | ||
| 462 | |||
| 463 | if isinstance(args.collection, str): | 419 | if isinstance(args.collection, str): |
| 464 | args.collection = [args.collection] | 420 | args.collection = [args.collection] |
| 465 | 421 | ||
| @@ -503,7 +459,7 @@ def main(): | |||
| 503 | 459 | ||
| 504 | vae.enable_slicing() | 460 | vae.enable_slicing() |
| 505 | vae.set_use_memory_efficient_attention_xformers(True) | 461 | vae.set_use_memory_efficient_attention_xformers(True) |
| 506 | unet.set_use_memory_efficient_attention_xformers(True) | 462 | unet.enable_xformers_memory_efficient_attention() |
| 507 | 463 | ||
| 508 | if args.gradient_checkpointing: | 464 | if args.gradient_checkpointing: |
| 509 | unet.enable_gradient_checkpointing() | 465 | unet.enable_gradient_checkpointing() |
