diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-01 22:01:47 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-01 22:01:47 +0100 |
| commit | 7c02c2fe68da2411623f0a11c1187ccf0f7743d8 (patch) | |
| tree | 106eddc16374eaa80966782168ab41c6c191145e /textual_inversion.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.tar.gz textual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.tar.bz2 textual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.zip | |
Update
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index d6be522..80f1d7d 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -545,6 +545,7 @@ def main(): | |||
| 545 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 545 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
| 546 | args.pretrained_model_name_or_path, subfolder='scheduler') | 546 | args.pretrained_model_name_or_path, subfolder='scheduler') |
| 547 | 547 | ||
| 548 | vae.enable_slicing() | ||
| 548 | unet.set_use_memory_efficient_attention_xformers(True) | 549 | unet.set_use_memory_efficient_attention_xformers(True) |
| 549 | 550 | ||
| 550 | if args.gradient_checkpointing: | 551 | if args.gradient_checkpointing: |
| @@ -814,6 +815,9 @@ def main(): | |||
| 814 | ) | 815 | ) |
| 815 | global_progress_bar.set_description("Total progress") | 816 | global_progress_bar.set_description("Total progress") |
| 816 | 817 | ||
| 818 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
| 819 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
| 820 | |||
| 817 | try: | 821 | try: |
| 818 | for epoch in range(num_epochs): | 822 | for epoch in range(num_epochs): |
| 819 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 823 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| @@ -827,7 +831,7 @@ def main(): | |||
| 827 | for step, batch in enumerate(train_dataloader): | 831 | for step, batch in enumerate(train_dataloader): |
| 828 | with accelerator.accumulate(text_encoder): | 832 | with accelerator.accumulate(text_encoder): |
| 829 | # Convert images to latent space | 833 | # Convert images to latent space |
| 830 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 834 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
| 831 | latents = latents * 0.18215 | 835 | latents = latents * 0.18215 |
| 832 | 836 | ||
| 833 | # Sample noise that we'll add to the latents | 837 | # Sample noise that we'll add to the latents |
| @@ -883,7 +887,6 @@ def main(): | |||
| 883 | token_embeds = text_encoder.get_input_embeddings().weight | 887 | token_embeds = text_encoder.get_input_embeddings().weight |
| 884 | 888 | ||
| 885 | # Get the index for tokens that we want to freeze | 889 | # Get the index for tokens that we want to freeze |
| 886 | index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id | ||
| 887 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | 890 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] |
| 888 | 891 | ||
| 889 | optimizer.step() | 892 | optimizer.step() |
| @@ -927,8 +930,6 @@ def main(): | |||
| 927 | 930 | ||
| 928 | accelerator.wait_for_everyone() | 931 | accelerator.wait_for_everyone() |
| 929 | 932 | ||
| 930 | print(token_embeds[placeholder_token_id]) | ||
| 931 | |||
| 932 | text_encoder.eval() | 933 | text_encoder.eval() |
| 933 | val_loss = 0.0 | 934 | val_loss = 0.0 |
| 934 | 935 | ||
