-rw-r--r--   dreambooth.py         101
-rw-r--r--   infer.py                6
-rw-r--r--   textual_inversion.py   15
3 files changed, 77 insertions, 45 deletions
diff --git a/dreambooth.py b/dreambooth.py
index da8399f..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -170,14 +170,14 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=300,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
         "--lr_cycles",
         type=int,
         default=None,
-        help="Number of restart cycles in the lr scheduler."
+        help="Number of restart cycles in the lr scheduler (if supported)."
     )
     parser.add_argument(
         "--use_ema",
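
Note on the two options above: --lr_warmup_steps and --lr_cycles are presumably fed into a warmup-plus-cosine-with-restarts schedule elsewhere in the script (the scheduler wiring is not part of this hunk). A minimal sketch of that kind of schedule, using the transformers helper rather than whatever this repo actually calls:

    import torch
    import torch.nn as nn
    from transformers import get_cosine_with_hard_restarts_schedule_with_warmup

    model = nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # 500 warmup steps (the new default), then cosine decay restarted twice over 10k steps.
    lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer,
        num_warmup_steps=500,
        num_training_steps=10_000,
        num_cycles=2,
    )
    for _ in range(10_000):
        optimizer.step()
        lr_scheduler.step()
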
@@ -506,11 +506,10 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    save_args(basepath, args)
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
 
-    # If passed along, set the training seed now.
-    if args.seed is not None:
-        set_seed(args.seed)
+    save_args(basepath, args)
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
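
Side note on the seed handling above: torch.random.seed() re-seeds the default generator and returns a 64-bit value, while set_seed() also seeds NumPy and Python's random, which expect a value that fits in 32 bits, hence the >> 32. A small sketch (the pick_seed helper is just for illustration):

    import torch
    from accelerate.utils import set_seed

    def pick_seed(seed=None):
        # Keep the upper 32 bits of the 64-bit seed so NumPy accepts it downstream.
        return seed if seed is not None else (torch.random.seed() >> 32)

    seed = pick_seed()   # no --seed given on the command line
    set_seed(seed)       # seeds torch, numpy and random consistently

One edge case worth knowing: "args.seed or ..." treats an explicit seed of 0 as unset; an "is None" check avoids that.
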
@@ -523,13 +522,22 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
 
-    ema_unet = EMAModel(
-        unet,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay,
-        device=accelerator.device
-    ) if args.use_ema else None
+    ema_unet = None
+    if args.use_ema:
+        ema_unet = EMAModel(
+            unet,
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+            device=accelerator.device
+        )
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    # Freeze text_encoder and vae
+    freeze_params(vae.parameters())
 
     if args.initializer_token is not None:
         # Convert the initializer_token, placeholder_token to ids
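
For context on the EMA knobs passed above: inv_gamma, power and max_value usually parameterize a decay that warms up from 0 toward max_value as training progresses; a sketch of that common schedule (not necessarily the exact internals of this repo's EMAModel):

    def ema_decay(step, inv_gamma=1.0, power=0.75, max_value=0.9999):
        # Decay starts near 0 (fast tracking) and approaches max_value (slow tracking).
        value = 1 - (1 + step / inv_gamma) ** -power
        return max(0.0, min(value, max_value))

    # Shadow weights are then updated as: shadow = decay * shadow + (1 - decay) * param
    print([round(ema_decay(s), 4) for s in (0, 10, 100, 10_000)])
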
@@ -545,22 +553,22 @@ def main():
         print(f"Training new token {args.placeholder_token}.")
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+        # Resize the token embeddings as we are adding new special tokens to the tokenizer
         text_encoder.resize_token_embeddings(len(tokenizer))
-        token_embeds = text_encoder.get_input_embeddings()
-        initializer_token_embeddings = token_embeds(initializer_token_ids)
-        token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings
-
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
+        # Initialise the newly added placeholder token with the embeddings of the initializer token
+        token_embeds = text_encoder.get_input_embeddings().weight.data
+        original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+        token_embeds[placeholder_token_id] = initializer_token_embeddings
 
-    # slice_size = unet.config.attention_head_dim // 2
-    # unet.set_attention_slice(slice_size)
+        freeze_params(itertools.chain(
+            text_encoder.text_model.encoder.parameters(),
+            text_encoder.text_model.final_layer_norm.parameters(),
+            text_encoder.text_model.embeddings.position_embedding.parameters(),
+        ))
 
-    # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
         args.learning_rate_unet = (
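
The block above is the textual-inversion-style setup: grow the embedding matrix, copy the initializer token's vector into the new row, snapshot the original rows, and freeze everything in the text encoder except its input embeddings. A self-contained sketch of the same pattern with a CLIP text encoder (model names and helper variables here are illustrative, not the repo's):

    import itertools
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    tokenizer.add_tokens("<concept>")
    placeholder_id = tokenizer.convert_tokens_to_ids("<concept>")
    initializer_id = tokenizer.encode("dog", add_special_tokens=False)[0]

    # New rows are appended for added tokens; seed the new row with the initializer's vector.
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data
    original_token_embeds = token_embeds.detach().clone()
    token_embeds[placeholder_id] = token_embeds[initializer_id]

    # Train only the embedding matrix; freeze the rest of the text encoder.
    for param in itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    ):
        param.requires_grad = False
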
@@ -583,6 +591,11 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
+    if args.initializer_token is not None:
+        text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()
+    else:
+        text_encoder_params_to_optimize = text_encoder.parameters()
+
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
@@ -591,7 +604,7 @@ def main():
                 'lr': args.learning_rate_unet,
             },
             {
-                'params': text_encoder.parameters(),
+                'params': text_encoder_params_to_optimize,
                 'lr': args.learning_rate_text,
             }
         ],
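
The two optimizer parameter groups above let the UNet and the text-encoder side (or just its embeddings) train at different learning rates; the same pattern in miniature, with placeholder values:

    import torch
    import torch.nn as nn

    unet_like = nn.Linear(8, 8)            # stand-in for the UNet
    embeddings = nn.Embedding(100, 8)      # stand-in for the text-encoder embeddings

    optimizer = torch.optim.AdamW(
        [
            {"params": unet_like.parameters(), "lr": 5e-6},
            {"params": embeddings.parameters(), "lr": 5e-4},
        ],
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
    )
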
@@ -849,9 +862,27 @@ def main():
                 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 accelerator.backward(loss)
+
+                if args.initializer_token is not None:
+                    # Keep the token embeddings fixed except the newly added
+                    # embeddings for the concept, as we only want to optimize the concept embeddings
+                    if accelerator.num_processes > 1:
+                        token_embeds = text_encoder.module.get_input_embeddings().weight
+                    else:
+                        token_embeds = text_encoder.get_input_embeddings().weight
+
+                    # Get the index for tokens that we want to freeze
+                    index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                    token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(itertools.chain(
-                        unet.parameters(), text_encoder.parameters()), args.max_grad_norm)
+                    params_to_clip = (
+                        unet.parameters()
+                        if args.initializer_token is not None
+                        else itertools.chain(unet.parameters(), text_encoder.parameters())
+                    )
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
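
What the inserted block above does, in isolation: only the placeholder row of the embedding matrix is meant to learn, so after each update every other row is pinned back to its original value. A toy version of the trick (here the restore runs after optimizer.step(), which is the more common placement):

    import torch
    import torch.nn as nn

    embedding = nn.Embedding(10, 4)
    original = embedding.weight.detach().clone()
    placeholder_id = 7
    optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-2)

    for _ in range(3):
        ids = torch.tensor([1, 7, 3, 7])
        loss = embedding(ids).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Pin every row except the placeholder back to its original value.
        keep = torch.arange(embedding.num_embeddings) != placeholder_id
        with torch.no_grad():
            embedding.weight[keep] = original[keep]

    assert torch.allclose(embedding.weight[0], original[0])           # frozen row unchanged
    assert not torch.allclose(embedding.weight[placeholder_id], original[placeholder_id])
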
@@ -896,8 +927,8 @@ def main():
             text_encoder.eval()
             val_loss = 0.0
 
-            for step, batch in enumerate(val_dataloader):
-                with torch.no_grad():
+            with torch.inference_mode():
+                for step, batch in enumerate(val_dataloader):
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
 
@@ -920,12 +951,12 @@ def main():
                     loss = loss.detach().item()
                     val_loss += loss
 
-                if accelerator.sync_gradients:
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
-                local_progress_bar.set_postfix(**logs)
+                    logs = {"val/loss": loss}
+                    local_progress_bar.set_postfix(**logs)
 
             val_loss /= len(val_dataloader)
 
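
Both scripts now wrap the whole validation pass in a single torch.inference_mode() instead of a per-batch torch.no_grad(); inference_mode also skips view and version-counter tracking, so it is slightly cheaper, with the caveat that tensors created inside it cannot be reused in autograd later. Minimal shape of the new loop:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 1)
    val_batches = [torch.randn(8, 4) for _ in range(3)]

    model.eval()
    val_loss = 0.0
    with torch.inference_mode():                 # one context around the whole loop
        for batch in val_batches:
            pred = model(batch)
            val_loss += nn.functional.mse_loss(pred, torch.zeros_like(pred)).item()
    val_loss /= len(val_batches)
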
diff --git a/infer.py b/infer.py
index 8e17c4e..01010eb 100644
--- a/infer.py
+++ b/infer.py
@@ -258,7 +258,7 @@ def generate(output_dir, pipeline, args):
     output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
     output_dir.mkdir(parents=True, exist_ok=True)
 
-    seed = args.seed or torch.random.seed()
+    args.seed = args.seed or torch.random.seed()
 
     save_args(output_dir, args)
 
@@ -276,7 +276,7 @@ def generate(output_dir, pipeline, args):
             dynamic_ncols=True
         )
 
-        generator = torch.Generator(device="cuda").manual_seed(seed + i)
+        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
         images = pipeline(
             prompt=args.prompt * (args.batch_size // len(args.prompt)),
             height=args.height,
@@ -290,7 +290,7 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
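
Storing the generated seed back on args has two effects: save_args() now records the seed that was actually used, and each batch i derives its generator (and its filenames) from args.seed + i, so any saved image can be regenerated from its filename. A rough sketch of that bookkeeping (CPU generator so it runs anywhere):

    import torch

    seed = torch.random.seed()                   # stand-in for args.seed when none was given
    for i in range(2):                           # one generator per batch
        generator = torch.Generator().manual_seed(seed + i)
        noise = torch.randn(1, 4, 8, 8, generator=generator)
        print(f"batch {i}: seed {seed + i}, first value {noise.flatten()[0].item():.4f}")
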
diff --git a/textual_inversion.py b/textual_inversion.py
index 8f266e0..fe56d36 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -521,6 +521,7 @@ def main():
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)
@@ -875,8 +876,8 @@ def main():
             text_encoder.eval()
             val_loss = 0.0
 
-            for step, batch in enumerate(val_dataloader):
-                with torch.no_grad():
+            with torch.inference_mode():
+                for step, batch in enumerate(val_dataloader):
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
 
@@ -899,12 +900,12 @@ def main():
                     loss = loss.detach().item()
                     val_loss += loss
 
-                if accelerator.sync_gradients:
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
-                local_progress_bar.set_postfix(**logs)
+                    logs = {"val/loss": loss}
+                    local_progress_bar.set_postfix(**logs)
 
             val_loss /= len(val_dataloader)
 