author     Volpeon <git@volpeon.ink>  2022-12-13 09:40:34 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-13 09:40:34 +0100
commit     b33ac00de283fe45edba689990dc96a5de93cd1e (patch)
tree       a3106f2e482f9e4b2ab9d9ff49faf0b529278f50 /dreambooth.py
parent     Dreambooth: Support loading Textual Inversion embeddings (diff)
Add support for resume in Textual Inversion
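The resume path works as follows: for every placeholder token, the script now looks for a previously saved "<token>.bin" file in the embeddings directory and, if it exists, loads that tensor straight into the text encoder's embedding matrix; otherwise the token keeps its initializer-based embedding. Each token is then reported together with its source ("init" or "file"). A minimal sketch of that idea follows; the helper name, its arguments, and the assumed {token: tensor} layout of the .bin file are illustrative assumptions, not code taken from the repository.

# Minimal sketch of the resume behaviour (not the repository's exact code).
# Assumption: the saved "<token>.bin" file holds a {token: tensor} mapping.
from pathlib import Path

import torch
from transformers import CLIPTextModel, CLIPTokenizer


def load_or_init_embedding(tokenizer: CLIPTokenizer,
                           text_encoder: CLIPTextModel,
                           placeholder_token: str,
                           initializer_token: str,
                           embeddings_dir: Path) -> str:
    """Returns "file" when a saved embedding was resumed, "init" otherwise."""
    tokenizer.add_tokens(placeholder_token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    initializer_id = tokenizer.convert_tokens_to_ids(initializer_token)
    token_embeds = text_encoder.get_input_embeddings().weight

    embedding_file = embeddings_dir.joinpath(f"{placeholder_token}.bin")

    with torch.no_grad():
        if embedding_file.exists() and embedding_file.is_file():
            embedding_data = torch.load(embedding_file, map_location="cpu")
            emb = next(iter(embedding_data.values()))  # assumed {token: tensor} layout
            token_embeds[token_id] = emb.reshape(-1)   # resume the saved state
            return "file"
        token_embeds[token_id] = token_embeds[initializer_id]  # fresh start
        return "init"

In the diff itself the same idea appears as the new embedding_source variable and the per-token "- {token_id} {token} ({embedding_source})" report.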
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  49
1 file changed, 23 insertions(+), 26 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 3110c6d..9a6f70a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
@@ -204,7 +204,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_epochs",
         type=int,
-        default=20,
+        default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -558,11 +558,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
-    # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-    #     raise ValueError(
-    #         "Gradient accumulation is not supported when training the text encoder in distributed training. "
-    #         "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-    #     )
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
 
     instance_identifier = args.instance_identifier
 
@@ -645,9 +645,9 @@ def main():
 
     print(f"Token ID mappings:")
     for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-        print(f"- {token_id} {token}")
-
         embedding_file = embeddings_dir.joinpath(f"{token}.bin")
+        embedding_source = "init"
+
         if embedding_file.exists() and embedding_file.is_file():
             embedding_data = torch.load(embedding_file, map_location="cpu")
 
@@ -656,8 +656,11 @@ def main():
                 emb = emb.unsqueeze(0)
 
             token_embeds[token_id] = emb
+            embedding_source = "file"
 
-    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+        print(f"- {token_id} {token} ({embedding_source})")
+
+    original_token_embeds = token_embeds.clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
     for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
@@ -946,7 +949,7 @@ def main():
         sample_checkpoint = False
 
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(itertools.chain(unet, text_encoder)):
+            with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
@@ -997,16 +1000,6 @@ def main():
 
                 accelerator.backward(loss)
 
-                if not args.train_text_encoder:
-                    # Keep the token embeddings fixed except the newly added
-                    # embeddings for the concept, as we only want to optimize the concept embeddings
-                    if accelerator.num_processes > 1:
-                        token_embeds = text_encoder.module.get_input_embeddings().weight
-                    else:
-                        token_embeds = text_encoder.get_input_embeddings().weight
-
-                    token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
                 if accelerator.sync_gradients:
                     params_to_clip = (
                         itertools.chain(unet.parameters(), text_encoder.parameters())
@@ -1022,6 +1015,12 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
+                if not args.train_text_encoder:
+                    # Let's make sure we don't update any embedding weights besides the newly added token
+                    with torch.no_grad():
+                        text_encoder.get_input_embeddings(
+                        ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
@@ -1032,9 +1031,6 @@ def main():
 
                 global_step += 1
 
-                if global_step % args.sample_frequency == 0:
-                    sample_checkpoint = True
-
                 logs = {
                     "train/loss": avg_loss.avg.item(),
                     "train/acc": avg_acc.avg.item(),
@@ -1117,8 +1113,9 @@ def main():
                     f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                 max_acc_val = avg_acc_val.avg.item()
 
-        if sample_checkpoint and accelerator.is_main_process:
-            checkpointer.save_samples(global_step, args.sample_steps)
+        if accelerator.is_main_process:
+            if epoch % args.sample_frequency == 0:
+                checkpointer.save_samples(global_step, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
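Because the text encoder's embedding matrix grows to hold the placeholder tokens but may otherwise stay frozen (when --train_text_encoder is off), the loop has to keep every pre-existing row fixed. The commit moves that fix-up from before the optimizer step (patching token_embeds.data) to after optimizer.step(), under torch.no_grad(), so the frozen rows end every step unchanged. The following self-contained toy reproduces the pattern; the nn.Embedding, the sizes, and the dummy loss are stand-ins for the CLIP text encoder and the real training objective, not the repository's code.

# Toy sketch: train only the newly added embedding rows by snapshotting the
# matrix once and restoring the frozen rows after every optimizer step.
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)  # stand-in for get_input_embeddings()
placeholder_token_id = [8, 9]                                 # ids of the newly added tokens

token_embeds = embedding.weight
original_token_embeds = token_embeds.detach().clone()

# Every row that is not a placeholder token stays frozen.
all_ids = torch.arange(token_embeds.shape[0])
index_fixed_tokens = all_ids[~torch.isin(all_ids, torch.tensor(placeholder_token_id))]

optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-2)

for _ in range(3):
    ids = torch.randint(0, 10, (16,))
    loss = embedding(ids).pow(2).mean()  # dummy objective that touches all rows
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # Same move as the new hunk: undo whatever the step did to the frozen rows.
    with torch.no_grad():
        token_embeds[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]

# The frozen rows are bit-for-bit identical to the snapshot; only ids 8 and 9 moved.
assert torch.equal(token_embeds[index_fixed_tokens], original_token_embeds[index_fixed_tokens])

Restoring after the step is the point of the change: the old placement reset the values before optimizer.step(), so the optimizer could still move the frozen rows within the same step.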