 dreambooth.py        | 41 ++++++++++++++++++++++++++++++++---------
 dreambooth_plus.py   | 33 ++++++++++++++++++++-------------
 textual_inversion.py | 28 +++++++++++++++++-----------
 3 files changed, 69 insertions(+), 33 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 1ba8dc0..9e2645b 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler
+from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
@@ -150,10 +150,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -167,7 +173,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -296,6 +302,13 @@ def parse_args():
     return args
 
 
+def save_args(basepath: Path, args, extra={}):
+    info = {"args": vars(args)}
+    info["args"].update(extra)
+    with open(basepath.joinpath("args.json"), "w") as f:
+        json.dump(info, f, indent=4)
+
+
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -455,6 +468,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
+    save_args(basepath, args)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -614,12 +629,20 @@ def main():
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-    )
+    if args.lr_scheduler == "cosine_with_restarts":
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            num_cycles=args.lr_cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        )
 
     unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
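
For context: with the dreambooth.py changes above, the scheduler choice now branches on --lr_scheduler, and the new --lr_cycles flag only affects the cosine_with_restarts path. A minimal standalone sketch of that selection logic, assuming a dummy model and optimizer (the torch.nn.Linear, the AdamW settings, and the step count are illustrative placeholders, not values from this patch):

# Illustrative sketch only: dummy model/optimizer and made-up step count.
import torch
from diffusers.optimization import (
    get_scheduler,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)

model = torch.nn.Linear(4, 4)  # stand-in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

lr_scheduler_name = "cosine_with_restarts"
lr_warmup_steps = 300   # mirrors the new default
max_train_steps = 1000  # made up for the sketch
lr_cycles = 2           # mirrors the new default

if lr_scheduler_name == "cosine_with_restarts":
    # Warm up, then decay to zero and restart lr_cycles times.
    lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=max_train_steps,
        num_cycles=lr_cycles,
    )
else:
    lr_scheduler = get_scheduler(
        lr_scheduler_name,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=max_train_steps,
    )

for _ in range(max_train_steps):
    optimizer.step()
    lr_scheduler.step()
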
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index eeee424..42994af 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -118,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1300,
+        default=1200,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -141,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -153,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -162,10 +162,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=2,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_ema",
         action="store_true",
         default=True,
@@ -179,7 +185,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6 / 7
+        default=9 / 10
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -565,6 +571,7 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
     token_embeds[placeholder_token_id] = initializer_token_embeddings
 
@@ -717,11 +724,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
@@ -857,15 +863,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
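
The training-loop change in dreambooth_plus.py (and the matching one in textual_inversion.py below) swaps gradient zeroing for restoring the original embedding rows after each backward pass. In isolation the pattern looks like the following sketch; the toy embedding layer, vocabulary size, and placeholder id are assumptions used only to illustrate the restore step taken from the patch:

# Toy illustration of freeze-by-restore: only the restore lines mirror the patch;
# the embedding layer, sizes, and placeholder id are made up for the example.
import torch

vocab_size, embed_dim = 16, 8
placeholder_token_id = 5  # hypothetical newly added token
embedding = torch.nn.Embedding(vocab_size, embed_dim)
original_token_embeds = embedding.weight.data.detach().clone()

optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-3)

for _ in range(3):
    ids = torch.randint(0, vocab_size, (4,))
    loss = embedding(ids).pow(2).mean()
    loss.backward()

    # As in the patch: right after backward and before the optimizer step,
    # reset every row except the placeholder's back to its original value,
    # so the other embeddings are pulled back each iteration.
    index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
    embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

    optimizer.step()
    optimizer.zero_grad()
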
diff --git a/textual_inversion.py b/textual_inversion.py
index 2109d13..61c96b7 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -155,10 +155,16 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=300,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=15,
+        help="Number of restart cycles in the lr scheduler."
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -515,13 +521,13 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
             args.placeholder_token]
     else:
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
         token_embeds[placeholder_token_id] = initializer_token_embeddings
 
     # Freeze vae and unet
@@ -662,11 +668,10 @@ def main():
 
     if args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            args.lr_scheduler,
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=num_update_steps_per_epoch,
+            num_cycles=args.lr_cycles,
         )
     else:
         lr_scheduler = get_scheduler(
@@ -803,15 +808,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Zero out the gradients for all token embeddings except the newly added
+                # Keep the token embeddings fixed except the newly added
                 # embeddings for the concept, as we only want to optimize the concept embeddings
                 if accelerator.num_processes > 1:
-                    grads = text_encoder.module.get_input_embeddings().weight.grad
+                    token_embeds = text_encoder.module.get_input_embeddings().weight
                 else:
-                    grads = text_encoder.get_input_embeddings().weight.grad
-                # Get the index for tokens that we want to zero the grads for
-                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
-                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+                    token_embeds = text_encoder.get_input_embeddings().weight
+
+                # Get the index for tokens that we want to freeze
+                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
 
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
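
textual_inversion.py now builds initializer_token_embeddings only when no checkpoint is resumed; with --resume_checkpoint, the placeholder embedding is read from a dict keyed by the placeholder token string. A sketch of the save/load round trip this implies; the file name, embedding size, and the save side are assumptions, only the load expression mirrors the script:

# Hypothetical round trip for the --resume_checkpoint path: the checkpoint layout
# (a dict mapping the placeholder string to its embedding) is inferred from the
# load call in the diff; the save side and file name are assumptions.
import torch

placeholder_token = "<my-concept>"  # hypothetical placeholder string
learned_embeds = torch.randn(768)   # stand-in for a trained text-encoder embedding

torch.save({placeholder_token: learned_embeds}, "resume.bin")

# What the script does when --resume_checkpoint is passed:
resumed = torch.load("resume.bin")[placeholder_token]
assert torch.equal(resumed, learned_embeds)
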