path: root/dreambooth.py
author     Volpeon <git@volpeon.ink>  2022-10-21 09:50:46 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-21 09:50:46 +0200
commit     ca914af018632b6231fb3ee4fcd5cdbdc467c784 (patch)
tree       01af701c5ac740518cdbc4001592a3f9a29cc57a /dreambooth.py
parent     Dreambooth: Added option to insert a new input token; removed Dreambooth Plus (diff)
download   textual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.tar.gz
           textual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.tar.bz2
           textual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.zip
Add optional TI functionality to Dreambooth
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  101
1 file changed, 66 insertions, 35 deletions
diff --git a/dreambooth.py b/dreambooth.py
index da8399f..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -170,14 +170,14 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=300,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
         "--lr_cycles",
         type=int,
         default=None,
-        help="Number of restart cycles in the lr scheduler."
+        help="Number of restart cycles in the lr scheduler (if supported)."
     )
     parser.add_argument(
         "--use_ema",
@@ -506,11 +506,10 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    save_args(basepath, args)
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
 
-    # If passed along, set the training seed now.
-    if args.seed is not None:
-        set_seed(args.seed)
+    save_args(basepath, args)
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
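The new seeding logic above falls back to a fresh random seed when --seed is unset: torch.random.seed() re-seeds PyTorch's default generator with a non-deterministic 64-bit value and returns it, and the right shift by 32 keeps the result in a 32-bit range. A minimal stand-alone sketch of that pattern follows; set_seed here is a stand-in for the accelerate.utils.set_seed helper the script uses.

import random
import torch

def resolve_seed(seed):
    # Mirrors the patch: a falsy seed (None or 0) triggers the random fallback.
    # torch.random.seed() returns a 64-bit value; >> 32 keeps a 32-bit seed.
    return seed or (torch.random.seed() >> 32)

def set_seed(seed):
    # Stand-in for accelerate.utils.set_seed: seed the RNGs we depend on.
    random.seed(seed)
    torch.manual_seed(seed)

seed = resolve_seed(None)
set_seed(seed)
print(f"training seed: {seed}")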
@@ -523,13 +522,22 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
 
-    ema_unet = EMAModel(
-        unet,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay,
-        device=accelerator.device
-    ) if args.use_ema else None
+    ema_unet = None
+    if args.use_ema:
+        ema_unet = EMAModel(
+            unet,
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+            device=accelerator.device
+        )
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    # Freeze text_encoder and vae
+    freeze_params(vae.parameters())
 
     if args.initializer_token is not None:
         # Convert the initializer_token, placeholder_token to ids
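Two things happen in the hunk above: the EMA copy of the UNet is now only built when --use_ema is set, and gradient checkpointing plus the VAE freeze move up ahead of the token handling. The inv_gamma / power / max_value arguments describe the usual warmup-style EMA decay; below is a self-contained sketch of that curve as an illustration of the schedule, not of the EMAModel class itself.

def ema_decay(step, inv_gamma=1.0, power=0.75, max_value=0.9999):
    # Warmup-style EMA decay: close to 0 early in training, approaching max_value later.
    value = 1.0 - (1.0 + step / inv_gamma) ** -power
    return max(0.0, min(value, max_value))

# The shadow weights are then updated as:
#   ema_param = decay * ema_param + (1 - decay) * param
for step in (0, 10, 100, 1_000, 10_000):
    print(step, round(ema_decay(step), 4))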
@@ -545,22 +553,22 @@ def main():
         print(f"Training new token {args.placeholder_token}.")
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+        # Resize the token embeddings as we are adding new special tokens to the tokenizer
         text_encoder.resize_token_embeddings(len(tokenizer))
-        token_embeds = text_encoder.get_input_embeddings()
-        initializer_token_embeddings = token_embeds(initializer_token_ids)
-        token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings
-
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
+        # Initialise the newly added placeholder token with the embeddings of the initializer token
+        token_embeds = text_encoder.get_input_embeddings().weight.data
+        original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+        token_embeds[placeholder_token_id] = initializer_token_embeddings
 
-    # slice_size = unet.config.attention_head_dim // 2
-    # unet.set_attention_slice(slice_size)
+        freeze_params(itertools.chain(
+            text_encoder.text_model.encoder.parameters(),
+            text_encoder.text_model.final_layer_norm.parameters(),
+            text_encoder.text_model.embeddings.position_embedding.parameters(),
+        ))
 
-    # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
         args.learning_rate_unet = (
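The hunk above is the heart of the optional TI path: resize the embedding matrix for the new placeholder token, snapshot the original rows, copy the initializer token's embedding into the placeholder slot, and freeze every part of the text encoder except its input embeddings. A hedged, self-contained sketch of the same steps with the transformers CLIP classes follows; the model id and the two token strings are placeholders chosen for illustration, not values from the patch.

import itertools
from transformers import CLIPTextModel, CLIPTokenizer

# Assumed model id and token strings, for illustration only.
pretrained = "openai/clip-vit-large-patch14"
placeholder_token = "<concept>"
initializer_token = "dog"

tokenizer = CLIPTokenizer.from_pretrained(pretrained)
text_encoder = CLIPTextModel.from_pretrained(pretrained)

tokenizer.add_tokens(placeholder_token)
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
initializer_token_id = tokenizer.encode(initializer_token, add_special_tokens=False)[0]

# Grow the embedding matrix to cover the new token, snapshot it, then copy
# the initializer token's row into the placeholder's slot.
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
original_token_embeds = token_embeds.detach().clone()
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

# Freeze everything in the text encoder except the input embedding table.
for param in itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters(),
):
    param.requires_grad = False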
@@ -583,6 +591,11 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
+    if args.initializer_token is not None:
+        text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()
+    else:
+        text_encoder_params_to_optimize = text_encoder.parameters()
+
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
@@ -591,7 +604,7 @@ def main():
                 'lr': args.learning_rate_unet,
             },
             {
-                'params': text_encoder.parameters(),
+                'params': text_encoder_params_to_optimize,
                 'lr': args.learning_rate_text,
             }
         ],
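Combined with the previous hunk, the optimizer now sees two parameter groups: the full UNet at its own learning rate, and either the whole text encoder or just its embedding table at the text learning rate. A toy sketch of that two-group AdamW setup; the modules and learning rates below are stand-ins, not the script's values.

import torch
from torch import nn

# Toy stand-ins: `unet` for the diffusion UNet, `text_embeddings` for
# text_encoder.get_input_embeddings() when --initializer_token is set.
unet = nn.Linear(8, 8)
text_embeddings = nn.Embedding(100, 8)

optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters(), "lr": 5e-6},             # --learning_rate_unet (illustrative)
        {"params": text_embeddings.parameters(), "lr": 5e-4},  # --learning_rate_text (illustrative)
    ],
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)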
@@ -849,9 +862,27 @@ def main():
                 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 accelerator.backward(loss)
+
+                if args.initializer_token is not None:
+                    # Keep the token embeddings fixed except the newly added
+                    # embeddings for the concept, as we only want to optimize the concept embeddings
+                    if accelerator.num_processes > 1:
+                        token_embeds = text_encoder.module.get_input_embeddings().weight
+                    else:
+                        token_embeds = text_encoder.get_input_embeddings().weight
+
+                    # Get the index for tokens that we want to freeze
+                    index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                    token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(itertools.chain(
-                        unet.parameters(), text_encoder.parameters()), args.max_grad_norm)
+                    params_to_clip = (
+                        unet.parameters()
+                        if args.initializer_token is not None
+                        else itertools.chain(unet.parameters(), text_encoder.parameters())
+                    )
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
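Since the embedding table is a single parameter, the optimizer updates every row of it; the patch therefore keeps the snapshot taken earlier (original_token_embeds) and copies it back over all rows except the placeholder's, so, as its comment says, only the concept embedding is effectively optimized. Below is a self-contained toy of that restore trick; the vocabulary size, objective, and optimizer are invented for illustration, and the toy restores after the step, whereas the patch performs the equivalent restore between backward() and step() on every iteration.

import torch
from torch import nn

torch.manual_seed(0)

vocab_size, dim, placeholder_token_id = 10, 4, 7
embeddings = nn.Embedding(vocab_size, dim)
original_token_embeds = embeddings.weight.detach().clone()

optimizer = torch.optim.SGD(embeddings.parameters(), lr=0.1)

for _ in range(3):
    # Make sure the placeholder token appears in every dummy batch.
    ids = torch.cat([torch.randint(0, vocab_size, (7,)), torch.tensor([placeholder_token_id])])
    loss = embeddings(ids).pow(2).mean()   # dummy training objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Re-pin every embedding row except the placeholder's to its original value.
    index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
    with torch.no_grad():
        embeddings.weight[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

# Only the placeholder row has drifted from the snapshot.
print((embeddings.weight.detach() - original_token_embeds).abs().sum(dim=1))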
@@ -896,8 +927,8 @@ def main():
         text_encoder.eval()
         val_loss = 0.0
 
-        for step, batch in enumerate(val_dataloader):
-            with torch.no_grad():
+        with torch.inference_mode():
+            for step, batch in enumerate(val_dataloader):
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
@@ -920,12 +951,12 @@ def main():
                 loss = loss.detach().item()
                 val_loss += loss
 
-            if accelerator.sync_gradients:
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-            logs = {"val/loss": loss}
-            local_progress_bar.set_postfix(**logs)
+                logs = {"val/loss": loss}
+                local_progress_bar.set_postfix(**logs)
 
         val_loss /= len(val_dataloader)
 
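The final two hunks wrap the whole validation pass in a single torch.inference_mode() block instead of a per-batch torch.no_grad(), and re-indent the progress and logging lines into the loop accordingly; inference_mode disables autograd tracking along with its view and version-counter bookkeeping. A minimal sketch of the resulting loop shape, with a toy model and dataset standing in for the real VAE/UNet pipeline:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the model and validation data.
model = nn.Linear(16, 1)
val_dataloader = DataLoader(TensorDataset(torch.randn(32, 16), torch.randn(32, 1)), batch_size=8)

model.eval()
val_loss = 0.0

with torch.inference_mode():   # one context around the whole loop
    for step, (inputs, targets) in enumerate(val_dataloader):
        loss = nn.functional.mse_loss(model(inputs), targets)
        val_loss += loss.item()

val_loss /= len(val_dataloader)
print(f"val/loss: {val_loss:.4f}")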