diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-01 22:01:47 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-01 22:01:47 +0100 |
| commit | 7c02c2fe68da2411623f0a11c1187ccf0f7743d8 (patch) | |
| tree | 106eddc16374eaa80966782168ab41c6c191145e /dreambooth.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.tar.gz textual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.tar.bz2 textual-inversion-diff-7c02c2fe68da2411623f0a11c1187ccf0f7743d8.zip | |
Update
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 17 |
1 file changed, 9 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py index 31dbea2..1ead6dd 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -550,11 +550,11 @@ class Checkpointer: | |||
| 550 | def main(): | 550 | def main(): |
| 551 | args = parse_args() | 551 | args = parse_args() |
| 552 | 552 | ||
| 553 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: | 553 | # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: |
| 554 | raise ValueError( | 554 | # raise ValueError( |
| 555 | "Gradient accumulation is not supported when training the text encoder in distributed training. " | 555 | # "Gradient accumulation is not supported when training the text encoder in distributed training. " |
| 556 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | 556 | # "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." |
| 557 | ) | 557 | # ) |
| 558 | 558 | ||
| 559 | instance_identifier = args.instance_identifier | 559 | instance_identifier = args.instance_identifier |
| 560 | 560 | ||
| @@ -899,6 +899,9 @@ def main(): | |||
| 899 | ) | 899 | ) |
| 900 | global_progress_bar.set_description("Total progress") | 900 | global_progress_bar.set_description("Total progress") |
| 901 | 901 | ||
| 902 | index_fixed_tokens = torch.arange(len(tokenizer)) | ||
| 903 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | ||
| 904 | |||
| 902 | try: | 905 | try: |
| 903 | for epoch in range(num_epochs): | 906 | for epoch in range(num_epochs): |
| 904 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 907 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| @@ -910,7 +913,7 @@ def main(): | |||
| 910 | sample_checkpoint = False | 913 | sample_checkpoint = False |
| 911 | 914 | ||
| 912 | for step, batch in enumerate(train_dataloader): | 915 | for step, batch in enumerate(train_dataloader): |
| 913 | with accelerator.accumulate(unet): | 916 | with accelerator.accumulate(itertools.chain(unet, text_encoder)): |
| 914 | # Convert images to latent space | 917 | # Convert images to latent space |
| 915 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 918 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 916 | latents = latents * 0.18215 | 919 | latents = latents * 0.18215 |
| @@ -967,8 +970,6 @@ def main(): | |||
| 967 | else: | 970 | else: |
| 968 | token_embeds = text_encoder.get_input_embeddings().weight | 971 | token_embeds = text_encoder.get_input_embeddings().weight |
| 969 | 972 | ||
| 970 | # Get the index for tokens that we want to freeze | ||
| 971 | index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id | ||
| 972 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | 973 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] |
| 973 | 974 | ||
| 974 | if accelerator.sync_gradients: | 975 | if accelerator.sync_gradients: |
