diff options
author | Volpeon <git@volpeon.ink> | 2022-12-13 09:40:34 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-13 09:40:34 +0100 |
commit | b33ac00de283fe45edba689990dc96a5de93cd1e (patch) | |
tree | a3106f2e482f9e4b2ab9d9ff49faf0b529278f50 /dreambooth.py | |
parent | Dreambooth: Support loading Textual Inversion embeddings (diff) | |
download | textual-inversion-diff-b33ac00de283fe45edba689990dc96a5de93cd1e.tar.gz textual-inversion-diff-b33ac00de283fe45edba689990dc96a5de93cd1e.tar.bz2 textual-inversion-diff-b33ac00de283fe45edba689990dc96a5de93cd1e.zip |
Add support for resume in Textual Inversion
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 49 |
1 files changed, 23 insertions, 26 deletions
diff --git a/dreambooth.py b/dreambooth.py index 3110c6d..9a6f70a 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -13,7 +13,7 @@ import torch.utils.checkpoint | |||
13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
14 | from accelerate.logging import get_logger | 14 | from accelerate.logging import get_logger |
15 | from accelerate.utils import LoggerType, set_seed | 15 | from accelerate.utils import LoggerType, set_seed |
16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel | 16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
18 | from diffusers.training_utils import EMAModel | 18 | from diffusers.training_utils import EMAModel |
19 | from PIL import Image | 19 | from PIL import Image |
@@ -204,7 +204,7 @@ def parse_args(): | |||
204 | parser.add_argument( | 204 | parser.add_argument( |
205 | "--lr_warmup_epochs", | 205 | "--lr_warmup_epochs", |
206 | type=int, | 206 | type=int, |
207 | default=20, | 207 | default=10, |
208 | help="Number of steps for the warmup in the lr scheduler." | 208 | help="Number of steps for the warmup in the lr scheduler." |
209 | ) | 209 | ) |
210 | parser.add_argument( | 210 | parser.add_argument( |
@@ -558,11 +558,11 @@ class Checkpointer: | |||
558 | def main(): | 558 | def main(): |
559 | args = parse_args() | 559 | args = parse_args() |
560 | 560 | ||
561 | # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: | 561 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: |
562 | # raise ValueError( | 562 | raise ValueError( |
563 | # "Gradient accumulation is not supported when training the text encoder in distributed training. " | 563 | "Gradient accumulation is not supported when training the text encoder in distributed training. " |
564 | # "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | 564 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." |
565 | # ) | 565 | ) |
566 | 566 | ||
567 | instance_identifier = args.instance_identifier | 567 | instance_identifier = args.instance_identifier |
568 | 568 | ||
@@ -645,9 +645,9 @@ def main(): | |||
645 | 645 | ||
646 | print(f"Token ID mappings:") | 646 | print(f"Token ID mappings:") |
647 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | 647 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
648 | print(f"- {token_id} {token}") | ||
649 | |||
650 | embedding_file = embeddings_dir.joinpath(f"{token}.bin") | 648 | embedding_file = embeddings_dir.joinpath(f"{token}.bin") |
649 | embedding_source = "init" | ||
650 | |||
651 | if embedding_file.exists() and embedding_file.is_file(): | 651 | if embedding_file.exists() and embedding_file.is_file(): |
652 | embedding_data = torch.load(embedding_file, map_location="cpu") | 652 | embedding_data = torch.load(embedding_file, map_location="cpu") |
653 | 653 | ||
@@ -656,8 +656,11 @@ def main(): | |||
656 | emb = emb.unsqueeze(0) | 656 | emb = emb.unsqueeze(0) |
657 | 657 | ||
658 | token_embeds[token_id] = emb | 658 | token_embeds[token_id] = emb |
659 | embedding_source = "file" | ||
659 | 660 | ||
660 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | 661 | print(f"- {token_id} {token} ({embedding_source})") |
662 | |||
663 | original_token_embeds = token_embeds.clone().to(accelerator.device) | ||
661 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 664 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
662 | 665 | ||
663 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 666 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): |
@@ -946,7 +949,7 @@ def main(): | |||
946 | sample_checkpoint = False | 949 | sample_checkpoint = False |
947 | 950 | ||
948 | for step, batch in enumerate(train_dataloader): | 951 | for step, batch in enumerate(train_dataloader): |
949 | with accelerator.accumulate(itertools.chain(unet, text_encoder)): | 952 | with accelerator.accumulate(unet): |
950 | # Convert images to latent space | 953 | # Convert images to latent space |
951 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 954 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
952 | latents = latents * 0.18215 | 955 | latents = latents * 0.18215 |
@@ -997,16 +1000,6 @@ def main(): | |||
997 | 1000 | ||
998 | accelerator.backward(loss) | 1001 | accelerator.backward(loss) |
999 | 1002 | ||
1000 | if not args.train_text_encoder: | ||
1001 | # Keep the token embeddings fixed except the newly added | ||
1002 | # embeddings for the concept, as we only want to optimize the concept embeddings | ||
1003 | if accelerator.num_processes > 1: | ||
1004 | token_embeds = text_encoder.module.get_input_embeddings().weight | ||
1005 | else: | ||
1006 | token_embeds = text_encoder.get_input_embeddings().weight | ||
1007 | |||
1008 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | ||
1009 | |||
1010 | if accelerator.sync_gradients: | 1003 | if accelerator.sync_gradients: |
1011 | params_to_clip = ( | 1004 | params_to_clip = ( |
1012 | itertools.chain(unet.parameters(), text_encoder.parameters()) | 1005 | itertools.chain(unet.parameters(), text_encoder.parameters()) |
@@ -1022,6 +1015,12 @@ def main(): | |||
1022 | ema_unet.step(unet) | 1015 | ema_unet.step(unet) |
1023 | optimizer.zero_grad(set_to_none=True) | 1016 | optimizer.zero_grad(set_to_none=True) |
1024 | 1017 | ||
1018 | if not args.train_text_encoder: | ||
1019 | # Let's make sure we don't update any embedding weights besides the newly added token | ||
1020 | with torch.no_grad(): | ||
1021 | text_encoder.get_input_embeddings( | ||
1022 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] | ||
1023 | |||
1025 | avg_loss.update(loss.detach_(), bsz) | 1024 | avg_loss.update(loss.detach_(), bsz) |
1026 | avg_acc.update(acc.detach_(), bsz) | 1025 | avg_acc.update(acc.detach_(), bsz) |
1027 | 1026 | ||
@@ -1032,9 +1031,6 @@ def main(): | |||
1032 | 1031 | ||
1033 | global_step += 1 | 1032 | global_step += 1 |
1034 | 1033 | ||
1035 | if global_step % args.sample_frequency == 0: | ||
1036 | sample_checkpoint = True | ||
1037 | |||
1038 | logs = { | 1034 | logs = { |
1039 | "train/loss": avg_loss.avg.item(), | 1035 | "train/loss": avg_loss.avg.item(), |
1040 | "train/acc": avg_acc.avg.item(), | 1036 | "train/acc": avg_acc.avg.item(), |
@@ -1117,8 +1113,9 @@ def main(): | |||
1117 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 1113 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
1118 | max_acc_val = avg_acc_val.avg.item() | 1114 | max_acc_val = avg_acc_val.avg.item() |
1119 | 1115 | ||
1120 | if sample_checkpoint and accelerator.is_main_process: | 1116 | if accelerator.is_main_process: |
1121 | checkpointer.save_samples(global_step, args.sample_steps) | 1117 | if epoch % args.sample_frequency == 0: |
1118 | checkpointer.save_samples(global_step, args.sample_steps) | ||
1122 | 1119 | ||
1123 | # Create the pipeline using using the trained modules and save it. | 1120 | # Create the pipeline using using the trained modules and save it. |
1124 | if accelerator.is_main_process: | 1121 | if accelerator.is_main_process: |