| author | Volpeon <git@volpeon.ink> | 2022-12-28 18:08:36 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-28 18:08:36 +0100 |
| commit | 83725794618164210a12843381724252fdd82cc2 | |
| tree | ec29ade9891fe08dd10b5033214fc09237c2cb86 /train_dreambooth.py | |
| parent | Improved learning rate finder | |
Integrated updates from diffusers
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 16 |
1 file changed, 6 insertions(+), 10 deletions(-)
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 325fe90..202d52c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -580,12 +580,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
```
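The substantive change here is dropping the repo's local `freeze_params` helper in favor of `torch.nn.Module.requires_grad_`, following the diffusers training scripts: rather than chaining each submodule's parameter iterator through `itertools.chain`, the flag is flipped recursively on the module itself. A minimal sketch of the equivalence, assuming `freeze_params` was a plain loop over its argument (its body lives in `training.util` and is not shown in this diff):

```python
import itertools

from torch import nn


def freeze_params(params):
    # Presumed shape of the old training.util helper (not part of this
    # diff): set requires_grad directly on each parameter.
    for param in params:
        param.requires_grad = False


# Toy stand-ins for two of the frozen submodules, e.g.
# text_encoder.text_model.final_layer_norm and ...token_embedding.
layer_norm = nn.LayerNorm(768)
token_embedding = nn.Embedding(49408, 768)

# Old style: chain the parameter iterators and freeze one by one.
freeze_params(itertools.chain(layer_norm.parameters(),
                              token_embedding.parameters()))

# New style: nn.Module.requires_grad_ sets the flag on every parameter
# of the module (recursively) and returns the module itself.
layer_norm.requires_grad_(False)
token_embedding.requires_grad_(False)

assert not any(p.requires_grad for p in layer_norm.parameters())
assert not any(p.requires_grad for p in token_embedding.parameters())
```

The second hunk applies the same simplification to the epoch gate in the training loop, dropping a leftover `sample_checkpoint = False` along the way: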
```diff
@@ -905,9 +903,7 @@ def main():
         if epoch < args.train_text_encoder_epochs:
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
-            freeze_params(text_encoder.parameters())
-
-        sample_checkpoint = False
+            text_encoder.requires_grad_(False)
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
```
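For context, the gate touched by this last hunk trains the text encoder jointly with the UNet for the first `train_text_encoder_epochs` epochs and then freezes it for the rest of the run. A minimal sketch of that pattern with an illustrative stand-in module (only the `train_text_encoder_epochs` argument name comes from the actual script):

```python
from torch import nn

# Stand-in: the real script uses a CLIP text encoder here.
text_encoder = nn.Linear(768, 768)
train_text_encoder_epochs = 2

for epoch in range(5):
    if epoch < train_text_encoder_epochs:
        # Still fine-tuning the text encoder: keep it in training mode.
        text_encoder.train()
    elif epoch == train_text_encoder_epochs:
        # First epoch past the cutoff: freeze every parameter exactly
        # once; no gradients are computed for them from here on.
        text_encoder.requires_grad_(False)
    # ... the batch loop over train_dataloader would follow here ...
```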
