path: root/dreambooth_plus.py
author     Volpeon <git@volpeon.ink>  2022-10-16 14:39:39 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-16 14:39:39 +0200
commit     dee4c7135754543f1eb7ea616ee3847d34a85b51 (patch)
tree       4064b44bb79e499cf6a8f1ec38a83a4889f067a7 /dreambooth_plus.py
parent     Update (diff)
download   textual-inversion-diff-dee4c7135754543f1eb7ea616ee3847d34a85b51.tar.gz
textual-inversion-diff-dee4c7135754543f1eb7ea616ee3847d34a85b51.tar.bz2
textual-inversion-diff-dee4c7135754543f1eb7ea616ee3847d34a85b51.zip
Update
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--  dreambooth_plus.py | 33 ++++++++++++++++++++-------------
1 file changed, 20 insertions(+), 13 deletions(-)
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index eeee424..42994af 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1300,
+        default=1200,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -141,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -153,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -162,10 +162,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -179,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
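As a side note on the ema_power change (6/7 to 9/10): a minimal sketch of how a larger power affects the EMA decay, assuming the decay schedule decay(step) = 1 - (1 + step / inv_gamma) ** -power clamped to ema_max_decay, the form used by diffusers-style EMAModel implementations. The ema_decay helper below is hypothetical and not part of dreambooth_plus.py.

    # Hypothetical helper, assuming decay(step) = 1 - (1 + step / inv_gamma) ** -power,
    # clamped to max_decay. Shows the effect of raising ema_power.
    def ema_decay(step, power, inv_gamma=1.0, max_decay=0.9999):
        decay = 1.0 - (1.0 + step / inv_gamma) ** -power
        return min(max_decay, max(0.0, decay))

    for step in (100, 600, 1200):
        print(step, round(ema_decay(step, 6 / 7), 5), round(ema_decay(step, 9 / 10), 5))
    # The 9/10 column sits closer to max_decay at every step, i.e. the averaged
    # weights change more slowly than with the old 6/7 default.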
@@ -565,6 +571,7 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
     token_embeds[placeholder_token_id] = initializer_token_embeddings
 
@@ -717,11 +724,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
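The hunk above also drops the stray args.lr_scheduler positional argument and feeds the new --lr_cycles flag into num_cycles. A rough sketch of the resulting schedule in isolation, assuming the helper is imported from diffusers.optimization as in recent diffusers releases, and using this commit's defaults (300 warmup steps, 1200 training steps, 2 cycles); the throwaway SGD optimizer exists only so the scheduler has something to attach to.

    import torch
    from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

    # Dummy parameter/optimizer; only the shape of the learning-rate curve matters.
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=1e-6)
    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=300,
        num_training_steps=1200,
        num_cycles=2,  # two cosine cycles after warmup, i.e. one hard restart
    )
    for step in range(1200):
        optimizer.step()
        scheduler.step()
        if step in (299, 740, 750, 1199):
            print(step, scheduler.get_last_lr()[0])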
@@ -857,15 +863,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
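The last hunk replaces gradient zeroing with restoring the frozen rows from the original_token_embeds snapshot taken at startup. Zeroing the gradient does not by itself keep those rows fixed, since an Adam-style optimizer with weight decay can still move parameters whose gradient is zero; copying the snapshot back pins them exactly. A self-contained sketch of the pattern, with a toy nn.Embedding standing in for the text encoder's input embeddings and illustrative names throughout (the sketch restores right after the optimizer step; the commit does it after backward, which undoes the previous step's drift before the next update).

    import torch

    embedding = torch.nn.Embedding(10, 4)         # stand-in for the token embedding table
    placeholder_token_id = 3                      # illustrative index of the new concept token
    original_token_embeds = embedding.weight.data.detach().clone()

    optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-2, weight_decay=1e-2)
    loss = embedding(torch.tensor([3, 5])).sum()  # touches the placeholder and one other token
    loss.backward()
    optimizer.step()                              # weight decay nudges every row here

    # Restore every row except the placeholder; only the concept embedding keeps its update.
    index_fixed_tokens = torch.arange(embedding.num_embeddings) != placeholder_token_id
    embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

    assert torch.equal(embedding.weight.data[5], original_token_embeds[5])
    assert not torch.equal(embedding.weight.data[3], original_token_embeds[3])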