diff options
-rw-r--r-- | dreambooth.py | 41 | ||||
-rw-r--r-- | dreambooth_plus.py | 33 | ||||
-rw-r--r-- | textual_inversion.py | 28 |
3 files changed, 69 insertions, 33 deletions
diff --git a/dreambooth.py b/dreambooth.py index 1ba8dc0..9e2645b 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -15,7 +15,7 @@ from accelerate import Accelerator | |||
15 | from accelerate.logging import get_logger | 15 | from accelerate.logging import get_logger |
16 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
17 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel | 17 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel |
18 | from diffusers.optimization import get_scheduler | 18 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
19 | from diffusers.training_utils import EMAModel | 19 | from diffusers.training_utils import EMAModel |
20 | from PIL import Image | 20 | from PIL import Image |
21 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
@@ -150,10 +150,16 @@ def parse_args(): | |||
150 | parser.add_argument( | 150 | parser.add_argument( |
151 | "--lr_warmup_steps", | 151 | "--lr_warmup_steps", |
152 | type=int, | 152 | type=int, |
153 | default=500, | 153 | default=300, |
154 | help="Number of steps for the warmup in the lr scheduler." | 154 | help="Number of steps for the warmup in the lr scheduler." |
155 | ) | 155 | ) |
156 | parser.add_argument( | 156 | parser.add_argument( |
157 | "--lr_cycles", | ||
158 | type=int, | ||
159 | default=2, | ||
160 | help="Number of restart cycles in the lr scheduler." | ||
161 | ) | ||
162 | parser.add_argument( | ||
157 | "--use_ema", | 163 | "--use_ema", |
158 | action="store_true", | 164 | action="store_true", |
159 | default=True, | 165 | default=True, |
@@ -167,7 +173,7 @@ def parse_args(): | |||
167 | parser.add_argument( | 173 | parser.add_argument( |
168 | "--ema_power", | 174 | "--ema_power", |
169 | type=float, | 175 | type=float, |
170 | default=6 / 7 | 176 | default=9 / 10 |
171 | ) | 177 | ) |
172 | parser.add_argument( | 178 | parser.add_argument( |
173 | "--ema_max_decay", | 179 | "--ema_max_decay", |
@@ -296,6 +302,13 @@ def parse_args(): | |||
296 | return args | 302 | return args |
297 | 303 | ||
298 | 304 | ||
305 | def save_args(basepath: Path, args, extra={}): | ||
306 | info = {"args": vars(args)} | ||
307 | info["args"].update(extra) | ||
308 | with open(basepath.joinpath("args.json"), "w") as f: | ||
309 | json.dump(info, f, indent=4) | ||
310 | |||
311 | |||
299 | def freeze_params(params): | 312 | def freeze_params(params): |
300 | for param in params: | 313 | for param in params: |
301 | param.requires_grad = False | 314 | param.requires_grad = False |
@@ -455,6 +468,8 @@ def main(): | |||
455 | 468 | ||
456 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 469 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) |
457 | 470 | ||
471 | save_args(basepath, args) | ||
472 | |||
458 | # If passed along, set the training seed now. | 473 | # If passed along, set the training seed now. |
459 | if args.seed is not None: | 474 | if args.seed is not None: |
460 | set_seed(args.seed) | 475 | set_seed(args.seed) |
@@ -614,12 +629,20 @@ def main(): | |||
614 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 629 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
615 | overrode_max_train_steps = True | 630 | overrode_max_train_steps = True |
616 | 631 | ||
617 | lr_scheduler = get_scheduler( | 632 | if args.lr_scheduler == "cosine_with_restarts": |
618 | args.lr_scheduler, | 633 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
619 | optimizer=optimizer, | 634 | optimizer=optimizer, |
620 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 635 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
621 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 636 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
622 | ) | 637 | num_cycles=args.lr_cycles, |
638 | ) | ||
639 | else: | ||
640 | lr_scheduler = get_scheduler( | ||
641 | args.lr_scheduler, | ||
642 | optimizer=optimizer, | ||
643 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | ||
644 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
645 | ) | ||
623 | 646 | ||
624 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 647 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
625 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 648 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index eeee424..42994af 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -118,7 +118,7 @@ def parse_args(): | |||
118 | parser.add_argument( | 118 | parser.add_argument( |
119 | "--max_train_steps", | 119 | "--max_train_steps", |
120 | type=int, | 120 | type=int, |
121 | default=1300, | 121 | default=1200, |
122 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 122 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
123 | ) | 123 | ) |
124 | parser.add_argument( | 124 | parser.add_argument( |
@@ -141,7 +141,7 @@ def parse_args(): | |||
141 | parser.add_argument( | 141 | parser.add_argument( |
142 | "--learning_rate_text", | 142 | "--learning_rate_text", |
143 | type=float, | 143 | type=float, |
144 | default=5e-6, | 144 | default=1e-6, |
145 | help="Initial learning rate (after the potential warmup period) to use.", | 145 | help="Initial learning rate (after the potential warmup period) to use.", |
146 | ) | 146 | ) |
147 | parser.add_argument( | 147 | parser.add_argument( |
@@ -153,7 +153,7 @@ def parse_args(): | |||
153 | parser.add_argument( | 153 | parser.add_argument( |
154 | "--lr_scheduler", | 154 | "--lr_scheduler", |
155 | type=str, | 155 | type=str, |
156 | default="cosine", | 156 | default="cosine_with_restarts", |
157 | help=( | 157 | help=( |
158 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 158 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
159 | ' "constant", "constant_with_warmup"]' | 159 | ' "constant", "constant_with_warmup"]' |
@@ -162,10 +162,16 @@ def parse_args(): | |||
162 | parser.add_argument( | 162 | parser.add_argument( |
163 | "--lr_warmup_steps", | 163 | "--lr_warmup_steps", |
164 | type=int, | 164 | type=int, |
165 | default=500, | 165 | default=300, |
166 | help="Number of steps for the warmup in the lr scheduler." | 166 | help="Number of steps for the warmup in the lr scheduler." |
167 | ) | 167 | ) |
168 | parser.add_argument( | 168 | parser.add_argument( |
169 | "--lr_cycles", | ||
170 | type=int, | ||
171 | default=2, | ||
172 | help="Number of restart cycles in the lr scheduler." | ||
173 | ) | ||
174 | parser.add_argument( | ||
169 | "--use_ema", | 175 | "--use_ema", |
170 | action="store_true", | 176 | action="store_true", |
171 | default=True, | 177 | default=True, |
@@ -179,7 +185,7 @@ def parse_args(): | |||
179 | parser.add_argument( | 185 | parser.add_argument( |
180 | "--ema_power", | 186 | "--ema_power", |
181 | type=float, | 187 | type=float, |
182 | default=6 / 7 | 188 | default=9 / 10 |
183 | ) | 189 | ) |
184 | parser.add_argument( | 190 | parser.add_argument( |
185 | "--ema_max_decay", | 191 | "--ema_max_decay", |
@@ -565,6 +571,7 @@ def main(): | |||
565 | 571 | ||
566 | # Initialise the newly added placeholder token with the embeddings of the initializer token | 572 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
567 | token_embeds = text_encoder.get_input_embeddings().weight.data | 573 | token_embeds = text_encoder.get_input_embeddings().weight.data |
574 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | ||
568 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 575 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
569 | token_embeds[placeholder_token_id] = initializer_token_embeddings | 576 | token_embeds[placeholder_token_id] = initializer_token_embeddings |
570 | 577 | ||
@@ -717,11 +724,10 @@ def main(): | |||
717 | 724 | ||
718 | if args.lr_scheduler == "cosine_with_restarts": | 725 | if args.lr_scheduler == "cosine_with_restarts": |
719 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 726 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
720 | args.lr_scheduler, | ||
721 | optimizer=optimizer, | 727 | optimizer=optimizer, |
722 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 728 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
723 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 729 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
724 | num_cycles=num_update_steps_per_epoch, | 730 | num_cycles=args.lr_cycles, |
725 | ) | 731 | ) |
726 | else: | 732 | else: |
727 | lr_scheduler = get_scheduler( | 733 | lr_scheduler = get_scheduler( |
@@ -857,15 +863,16 @@ def main(): | |||
857 | 863 | ||
858 | accelerator.backward(loss) | 864 | accelerator.backward(loss) |
859 | 865 | ||
860 | # Zero out the gradients for all token embeddings except the newly added | 866 | # Keep the token embeddings fixed except the newly added |
861 | # embeddings for the concept, as we only want to optimize the concept embeddings | 867 | # embeddings for the concept, as we only want to optimize the concept embeddings |
862 | if accelerator.num_processes > 1: | 868 | if accelerator.num_processes > 1: |
863 | grads = text_encoder.module.get_input_embeddings().weight.grad | 869 | token_embeds = text_encoder.module.get_input_embeddings().weight |
864 | else: | 870 | else: |
865 | grads = text_encoder.get_input_embeddings().weight.grad | 871 | token_embeds = text_encoder.get_input_embeddings().weight |
866 | # Get the index for tokens that we want to zero the grads for | 872 | |
867 | index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id | 873 | # Get the index for tokens that we want to freeze |
868 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) | 874 | index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id |
875 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | ||
869 | 876 | ||
870 | if accelerator.sync_gradients: | 877 | if accelerator.sync_gradients: |
871 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) | 878 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) |
diff --git a/textual_inversion.py b/textual_inversion.py index 2109d13..61c96b7 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -155,10 +155,16 @@ def parse_args(): | |||
155 | parser.add_argument( | 155 | parser.add_argument( |
156 | "--lr_warmup_steps", | 156 | "--lr_warmup_steps", |
157 | type=int, | 157 | type=int, |
158 | default=500, | 158 | default=300, |
159 | help="Number of steps for the warmup in the lr scheduler." | 159 | help="Number of steps for the warmup in the lr scheduler." |
160 | ) | 160 | ) |
161 | parser.add_argument( | 161 | parser.add_argument( |
162 | "--lr_cycles", | ||
163 | type=int, | ||
164 | default=15, | ||
165 | help="Number of restart cycles in the lr scheduler." | ||
166 | ) | ||
167 | parser.add_argument( | ||
162 | "--use_8bit_adam", | 168 | "--use_8bit_adam", |
163 | action="store_true", | 169 | action="store_true", |
164 | help="Whether or not to use 8-bit Adam from bitsandbytes." | 170 | help="Whether or not to use 8-bit Adam from bitsandbytes." |
@@ -515,13 +521,13 @@ def main(): | |||
515 | 521 | ||
516 | # Initialise the newly added placeholder token with the embeddings of the initializer token | 522 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
517 | token_embeds = text_encoder.get_input_embeddings().weight.data | 523 | token_embeds = text_encoder.get_input_embeddings().weight.data |
518 | 524 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | |
519 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
520 | 525 | ||
521 | if args.resume_checkpoint is not None: | 526 | if args.resume_checkpoint is not None: |
522 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ | 527 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ |
523 | args.placeholder_token] | 528 | args.placeholder_token] |
524 | else: | 529 | else: |
530 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
525 | token_embeds[placeholder_token_id] = initializer_token_embeddings | 531 | token_embeds[placeholder_token_id] = initializer_token_embeddings |
526 | 532 | ||
527 | # Freeze vae and unet | 533 | # Freeze vae and unet |
@@ -662,11 +668,10 @@ def main(): | |||
662 | 668 | ||
663 | if args.lr_scheduler == "cosine_with_restarts": | 669 | if args.lr_scheduler == "cosine_with_restarts": |
664 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 670 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
665 | args.lr_scheduler, | ||
666 | optimizer=optimizer, | 671 | optimizer=optimizer, |
667 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | 672 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
668 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 673 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
669 | num_cycles=num_update_steps_per_epoch, | 674 | num_cycles=args.lr_cycles, |
670 | ) | 675 | ) |
671 | else: | 676 | else: |
672 | lr_scheduler = get_scheduler( | 677 | lr_scheduler = get_scheduler( |
@@ -803,15 +808,16 @@ def main(): | |||
803 | 808 | ||
804 | accelerator.backward(loss) | 809 | accelerator.backward(loss) |
805 | 810 | ||
806 | # Zero out the gradients for all token embeddings except the newly added | 811 | # Keep the token embeddings fixed except the newly added |
807 | # embeddings for the concept, as we only want to optimize the concept embeddings | 812 | # embeddings for the concept, as we only want to optimize the concept embeddings |
808 | if accelerator.num_processes > 1: | 813 | if accelerator.num_processes > 1: |
809 | grads = text_encoder.module.get_input_embeddings().weight.grad | 814 | token_embeds = text_encoder.module.get_input_embeddings().weight |
810 | else: | 815 | else: |
811 | grads = text_encoder.get_input_embeddings().weight.grad | 816 | token_embeds = text_encoder.get_input_embeddings().weight |
812 | # Get the index for tokens that we want to zero the grads for | 817 | |
813 | index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id | 818 | # Get the index for tokens that we want to freeze |
814 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) | 819 | index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id |
820 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | ||
815 | 821 | ||
816 | optimizer.step() | 822 | optimizer.step() |
817 | if not accelerator.optimizer_step_was_skipped: | 823 | if not accelerator.optimizer_step_was_skipped: |