From 30e072fc795ea36eb92ba25e433b557ff650e35a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 1 Dec 2022 22:04:10 +0100
Subject: Update

---
 dreambooth.py        | 6 +++---
 textual_inversion.py | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 1ead6dd..f3f722e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -651,6 +651,9 @@ def main():
         text_encoder.text_model.embeddings.position_embedding.parameters(),
     ))
 
+    index_fixed_tokens = torch.arange(len(tokenizer))
+    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
@@ -899,9 +902,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    index_fixed_tokens = torch.arange(len(tokenizer))
-    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
diff --git a/textual_inversion.py b/textual_inversion.py
index 80f1d7d..cd2d22b 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -578,6 +578,9 @@ def main():
     for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
         token_embeds[token_id] = embeddings
 
+    index_fixed_tokens = torch.arange(len(tokenizer))
+    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+
     # Freeze vae and unet
     freeze_params(vae.parameters())
     freeze_params(unet.parameters())
@@ -815,9 +818,6 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    index_fixed_tokens = torch.arange(len(tokenizer))
-    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
-
     try:
         for epoch in range(num_epochs):
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-- 
cgit v1.2.3-70-g09d2
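
Note: the patch itself only moves the construction of index_fixed_tokens earlier in main(); it does not show how that index is consumed. The following is a minimal, self-contained sketch of how such an index of non-placeholder token ids is typically used in textual-inversion-style training: after each optimizer step, the embedding rows of every fixed token are restored so that only the placeholder embeddings actually learn. The toy embedding table, the placeholder ids, and the name original_embeds are assumptions for illustration, not code from this repository.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins: a small embedding table playing the role of the
    # text encoder's token embeddings, with ids 5 and 6 as the placeholder tokens.
    vocab_size, dim = 10, 4
    placeholder_token_id = [5, 6]
    token_embedding = nn.Embedding(vocab_size, dim)

    # Same construction as in the patch: ids of every token that must stay fixed.
    index_fixed_tokens = torch.arange(vocab_size)
    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]

    # Snapshot of the original weights (hypothetical name).
    original_embeds = token_embedding.weight.detach().clone()

    # ... an optimizer step may have touched every embedding row ...

    with torch.no_grad():
        # Restore the fixed rows; only the placeholder rows keep their updates.
        token_embedding.weight[index_fixed_tokens] = original_embeds[index_fixed_tokens]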