Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--  dreambooth_plus.py  33
1 files changed, 20 insertions, 13 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index eeee424..42994af 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1300,
+        default=1200,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -141,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -153,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -162,10 +162,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -179,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -565,6 +571,7 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
     token_embeds[placeholder_token_id] = initializer_token_embeddings
 
@@ -717,11 +724,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
@@ -857,15 +863,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
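
Note on the last two hunks: instead of zeroing the gradients of the non-placeholder rows, the script now snapshots the embedding matrix once (original_token_embeds) and copies the frozen rows back after every backward pass, presumably because zeroed gradients alone do not stop an optimizer with momentum or weight decay from drifting those rows. A minimal, self-contained sketch of the same restore trick follows; the embedding size, token id, dummy loss, and the placement of the restore after the optimizer step are illustrative assumptions, not the script's actual values.

import torch

# Hypothetical sizes and ids, standing in for the real tokenizer / text encoder.
vocab_size, dim = 8, 4
placeholder_token_id = 5

embedding = torch.nn.Embedding(vocab_size, dim)
original_token_embeds = embedding.weight.data.detach().clone()
optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-3)

for _ in range(3):
    loss = embedding(torch.arange(vocab_size)).pow(2).sum()  # dummy objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Restore every row except the placeholder from the snapshot, mirroring
    # the hunk above: only the concept embedding is allowed to change.
    index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
    embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

# Frozen rows are bit-identical to the snapshot; the placeholder row has moved.
assert torch.equal(embedding.weight.data[index_fixed_tokens], original_token_embeds[index_fixed_tokens])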

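Similarly, for the scheduler change: the stray positional args.lr_scheduler argument is dropped and the number of restarts now comes from the new --lr_cycles flag instead of num_update_steps_per_epoch. A short sketch of how that wiring behaves is below; it assumes the helper is imported from diffusers.optimization (transformers ships the same function) and uses placeholder values for the optimizer and step counts.

import torch
from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

# Placeholder parameter/optimizer, standing in for the unet + text encoder params.
params = [torch.nn.Parameter(torch.zeros(2, 2))]
optimizer = torch.optim.AdamW(params, lr=1e-6)

lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=300,     # --lr_warmup_steps default above
    num_training_steps=1200,  # --max_train_steps default above
    num_cycles=2,             # --lr_cycles: the cosine decays to zero twice, with a hard restart in between
)

for step in range(1200):
    optimizer.step()
    lr_scheduler.step()
    if step in (299, 749, 750):
        print(step, lr_scheduler.get_last_lr())  # end of warmup, trough of first cycle, restart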