diff options
author | Volpeon <git@volpeon.ink> | 2022-10-19 17:45:56 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-19 17:45:56 +0200 |
commit | 5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa (patch) | |
tree | fdf3c57c2f238e5381ad5e961e704fedd3e6816a /dreambooth.py | |
parent | Updated Dreambooth training (diff) | |
download | textual-inversion-diff-5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa.tar.gz textual-inversion-diff-5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa.tar.bz2 textual-inversion-diff-5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa.zip |
Dreambooth: Added option to insert a new input token; removed Dreambooth Plus
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 48 |
1 file changed, 42 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py index d1cf535..da8399f 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -69,6 +69,18 @@ def parse_args(): | |||
69 | help="A token to use as a placeholder for the concept.", | 69 | help="A token to use as a placeholder for the concept.", |
70 | ) | 70 | ) |
71 | parser.add_argument( | 71 | parser.add_argument( |
72 | "--placeholder_token", | ||
73 | type=str, | ||
74 | default="<*>", | ||
75 | help="A token to use as a placeholder for the concept.", | ||
76 | ) | ||
77 | parser.add_argument( | ||
78 | "--initializer_token", | ||
79 | type=str, | ||
80 | default=None, | ||
81 | help="A token to use as initializer word." | ||
82 | ) | ||
83 | parser.add_argument( | ||
72 | "--num_class_images", | 84 | "--num_class_images", |
73 | type=int, | 85 | type=int, |
74 | default=400, | 86 | default=400, |
@@ -114,7 +126,7 @@ def parse_args(): | |||
114 | parser.add_argument( | 126 | parser.add_argument( |
115 | "--max_train_steps", | 127 | "--max_train_steps", |
116 | type=int, | 128 | type=int, |
117 | default=3600, | 129 | default=6000, |
118 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 130 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
119 | ) | 131 | ) |
120 | parser.add_argument( | 132 | parser.add_argument( |
@@ -131,13 +143,13 @@ def parse_args(): | |||
131 | parser.add_argument( | 143 | parser.add_argument( |
132 | "--learning_rate_unet", | 144 | "--learning_rate_unet", |
133 | type=float, | 145 | type=float, |
134 | default=3e-6, | 146 | default=2e-6, |
135 | help="Initial learning rate (after the potential warmup period) to use.", | 147 | help="Initial learning rate (after the potential warmup period) to use.", |
136 | ) | 148 | ) |
137 | parser.add_argument( | 149 | parser.add_argument( |
138 | "--learning_rate_text", | 150 | "--learning_rate_text", |
139 | type=float, | 151 | type=float, |
140 | default=3e-6, | 152 | default=2e-6, |
141 | help="Initial learning rate (after the potential warmup period) to use.", | 153 | help="Initial learning rate (after the potential warmup period) to use.", |
142 | ) | 154 | ) |
143 | parser.add_argument( | 155 | parser.add_argument( |
@@ -476,8 +488,13 @@ class Checkpointer: | |||
476 | def main(): | 488 | def main(): |
477 | args = parse_args() | 489 | args = parse_args() |
478 | 490 | ||
491 | instance_identifier = args.instance_identifier | ||
492 | |||
493 | if args.placeholder_token is not None: | ||
494 | instance_identifier = instance_identifier.format(args.placeholder_token) | ||
495 | |||
479 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 496 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
480 | basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now) | 497 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) |
481 | basepath.mkdir(parents=True, exist_ok=True) | 498 | basepath.mkdir(parents=True, exist_ok=True) |
482 | 499 | ||
483 | accelerator = Accelerator( | 500 | accelerator = Accelerator( |
@@ -514,6 +531,25 @@ def main(): | |||
514 | device=accelerator.device | 531 | device=accelerator.device |
515 | ) if args.use_ema else None | 532 | ) if args.use_ema else None |
516 | 533 | ||
534 | if args.initializer_token is not None: | ||
535 | # Convert the initializer_token, placeholder_token to ids | ||
536 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
537 | print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.") | ||
538 | initializer_token_ids = torch.tensor(initializer_token_ids[:1]) | ||
539 | |||
540 | # Add the placeholder token in tokenizer | ||
541 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | ||
542 | if num_added_tokens == 0: | ||
543 | print(f"Re-using existing token {args.placeholder_token}.") | ||
544 | else: | ||
545 | print(f"Training new token {args.placeholder_token}.") | ||
546 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | ||
547 | |||
548 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
549 | token_embeds = text_encoder.get_input_embeddings() | ||
550 | initializer_token_embeddings = token_embeds(initializer_token_ids) | ||
551 | token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings | ||
552 | |||
517 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 553 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
518 | 554 | ||
519 | if args.gradient_checkpointing: | 555 | if args.gradient_checkpointing: |
@@ -605,7 +641,7 @@ def main(): | |||
605 | data_file=args.train_data_file, | 641 | data_file=args.train_data_file, |
606 | batch_size=args.train_batch_size, | 642 | batch_size=args.train_batch_size, |
607 | prompt_processor=prompt_processor, | 643 | prompt_processor=prompt_processor, |
608 | instance_identifier=args.instance_identifier, | 644 | instance_identifier=instance_identifier, |
609 | class_identifier=args.class_identifier, | 645 | class_identifier=args.class_identifier, |
610 | class_subdir="cls", | 646 | class_subdir="cls", |
611 | num_class_images=args.num_class_images, | 647 | num_class_images=args.num_class_images, |
@@ -735,7 +771,7 @@ def main(): | |||
735 | tokenizer=tokenizer, | 771 | tokenizer=tokenizer, |
736 | text_encoder=text_encoder, | 772 | text_encoder=text_encoder, |
737 | output_dir=basepath, | 773 | output_dir=basepath, |
738 | instance_identifier=args.instance_identifier, | 774 | instance_identifier=instance_identifier, |
739 | sample_image_size=args.sample_image_size, | 775 | sample_image_size=args.sample_image_size, |
740 | sample_batch_size=args.sample_batch_size, | 776 | sample_batch_size=args.sample_batch_size, |
741 | sample_batches=args.sample_batches, | 777 | sample_batches=args.sample_batches, |