author     Volpeon <git@volpeon.ink>  2022-12-01 22:01:47 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-01 22:01:47 +0100
commit     7c02c2fe68da2411623f0a11c1187ccf0f7743d8 (patch)
tree       106eddc16374eaa80966782168ab41c6c191145e /dreambooth.py
parent     Update (diff)
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  17
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 31dbea2..1ead6dd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -550,11 +550,11 @@ class Checkpointer:
 def main():
     args = parse_args()

-    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-        raise ValueError(
-            "Gradient accumulation is not supported when training the text encoder in distributed training. "
-            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-        )
+    # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+    #     raise ValueError(
+    #         "Gradient accumulation is not supported when training the text encoder in distributed training. "
+    #         "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+    #     )

     instance_identifier = args.instance_identifier

@@ -899,6 +899,9 @@ def main():
     )
     global_progress_bar.set_description("Total progress")

+    index_fixed_tokens = torch.arange(len(tokenizer))
+    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -910,7 +913,7 @@ def main():
             sample_checkpoint = False

             for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(unet):
+                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
                     # Convert images to latent space
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
@@ -967,8 +970,6 @@ def main():
                     else:
                         token_embeds = text_encoder.get_input_embeddings().weight

-                    # Get the index for tokens that we want to freeze
-                    index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
                     token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

                     if accelerator.sync_gradients:
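For context, the hoisted index computation can be read in isolation as in the minimal sketch below. The sizes and ids are made up for illustration; in dreambooth.py they come from the tokenizer, the text encoder, and the placeholder token(s) registered earlier in main(). The torch.isin form also covers the case where placeholder_token_id is a list of ids, which the previous single `!=` comparison did not.

    import torch

    # Illustrative values only (not from the repository).
    vocab_size = 10
    embed_dim = 4
    placeholder_token_id = [7, 9]  # torch.isin also works for a single id

    token_embeds = torch.nn.Embedding(vocab_size, embed_dim).weight
    original_token_embeds = token_embeds.detach().clone()

    # Computed once, before the training loop, as the patch hoists it above `try:`.
    index_fixed_tokens = torch.arange(vocab_size)
    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]

    # After a training step, restore the frozen rows so that only the
    # placeholder embeddings keep any updates.
    token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]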