| author | Volpeon <git@volpeon.ink> | 2022-12-28 18:08:36 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-28 18:08:36 +0100 |
| commit | 83725794618164210a12843381724252fdd82cc2 (patch) | |
| tree | ec29ade9891fe08dd10b5033214fc09237c2cb86 /train_ti.py | |
| parent | Improved learning rate finder (diff) | |
| download | textual-inversion-diff-83725794618164210a12843381724252fdd82cc2.tar.gz textual-inversion-diff-83725794618164210a12843381724252fdd82cc2.tar.bz2 textual-inversion-diff-83725794618164210a12843381724252fdd82cc2.zip | |
Integrated updates from diffusers
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 21 |
1 file changed, 13 insertions, 8 deletions
```diff
diff --git a/train_ti.py b/train_ti.py
index 870b2ba..d7696e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -533,12 +533,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -548,6 +546,9 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -715,7 +716,11 @@ def main():
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
-    unet.eval()
+
+    if args.gradient_checkpointing:
+        unet.train()
+    else:
+        unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
```
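The swap from the local `freeze_params` helper to `requires_grad_(False)` relies on PyTorch's built-in `nn.Module.requires_grad_`, which walks every parameter a module owns and toggles gradient tracking. A minimal sketch of the equivalence, using stand-in modules instead of the actual CLIP text encoder and assuming the removed `training.util.freeze_params` simply iterated over a parameter iterable:

```python
import itertools
import torch.nn as nn

def freeze_params(params):
    # Assumption: the helper removed in this commit did roughly this.
    for param in params:
        param.requires_grad = False

# Stand-ins for the text_encoder.text_model submodules frozen above.
encoder = nn.Linear(8, 8)
token_embedding = nn.Embedding(16, 8)

# Old pattern (removed): collect parameter iterators and freeze them one by one.
freeze_params(itertools.chain(encoder.parameters(), token_embedding.parameters()))

# New pattern: Module.requires_grad_(False) performs the same walk internally.
encoder.requires_grad_(False)
token_embedding.requires_grad_(False)

assert not any(p.requires_grad
               for p in itertools.chain(encoder.parameters(), token_embedding.parameters()))
```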
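The UNet change keeps the model frozen but leaves it in `train()` mode when `args.gradient_checkpointing` is set: diffusers' UNet blocks only take the checkpointed forward path while `self.training` is true, so an `eval()`-mode UNet would silently skip activation checkpointing. A sketch of that logic as a standalone helper (the function name is hypothetical, and the `enable_gradient_checkpointing()` call is included only to make the sketch self-contained):

```python
from diffusers import UNet2DConditionModel


def prepare_frozen_unet(unet: UNet2DConditionModel, gradient_checkpointing: bool) -> None:
    """Hypothetical helper mirroring the commit's handling of the frozen UNet."""
    unet.requires_grad_(False)                # no UNet parameters are trained in textual inversion
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()  # ModelMixin helper; flags the UNet blocks
        unet.train()                          # checkpointing only runs while self.training is True
    else:
        unet.eval()
```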
