path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2022-12-28 18:08:36 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-28 18:08:36 +0100
commit     83725794618164210a12843381724252fdd82cc2 (patch)
tree       ec29ade9891fe08dd10b5033214fc09237c2cb86 /train_ti.py
parent     Improved learning rate finder (diff)
download   textual-inversion-diff-83725794618164210a12843381724252fdd82cc2.tar.gz
           textual-inversion-diff-83725794618164210a12843381724252fdd82cc2.tar.bz2
           textual-inversion-diff-83725794618164210a12843381724252fdd82cc2.zip
Integrated updates from diffusers
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  21
1 file changed, 13 insertions(+), 8 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 870b2ba..d7696e5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -533,12 +533,10 @@ def main():
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -548,6 +546,9 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
 
+    if args.find_lr:
+        args.learning_rate = 1e2
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -715,7 +716,11 @@ def main():
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
-    unet.eval()
+
+    if args.gradient_checkpointing:
+        unet.train()
+    else:
+        unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
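
Note on the freeze_params removal (a hedged aside, not part of the commit): PyTorch's built-in Module.requires_grad_(False) flips the same flag that the old helper set by looping over parameters, so the text encoder submodules stay frozen while the embedding patched in by patch_trainable_embeddings keeps its gradients. A minimal sketch, assuming the removed helper matched the one in the diffusers example scripts:

# Hedged sketch: freeze_params is assumed to have looped over the given
# parameters and disabled gradients, mirroring the diffusers examples.
import torch.nn as nn

def freeze_params(params):
    for param in params:
        param.requires_grad = False

# Module.requires_grad_ performs the same walk per module, which is what
# the commit switches to:
layer_norm = nn.LayerNorm(768)
layer_norm.requires_grad_(False)
assert all(not p.requires_grad for p in layer_norm.parameters())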
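
Note on the --find_lr override (an assumption, not stated in the commit): setting args.learning_rate to 1e2 presumably hands the sweep a high ceiling to climb towards; the exact behaviour lives in training.lr.LRFinder, which is not part of this diff. A generic LR range test sketch with hypothetical names, not the repo's implementation:

# Generic LR range test sketch (hypothetical helper, not training.lr.LRFinder):
# sweep the learning rate geometrically up to a large ceiling such as 1e2 and
# record the loss so a workable rate can be read off the curve afterwards.
def lr_range_test(model, optimizer, loss_fn, batches, min_lr=1e-6, max_lr=1e2):
    history = []
    steps = max(len(batches) - 1, 1)
    for i, (inputs, targets) in enumerate(batches):
        lr = min_lr * (max_lr / min_lr) ** (i / steps)  # geometric ramp
        for group in optimizer.param_groups:
            group["lr"] = lr
        loss = loss_fn(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        history.append((lr, loss.item()))
    return history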
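
Note on unet.train() under --gradient_checkpointing (a hedged aside): the diffusers UNet blocks typically recompute activations only when both the checkpointing flag and self.training are set, so a UNet left in eval() would silently skip checkpointing; train mode here does not unfreeze any weights. A sketch of the intended setup, with an example checkpoint id that is not taken from this commit:

# Hedged sketch; the model id below is an example, not from this repository.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)
unet.requires_grad_(False)            # weights stay frozen either way
unet.enable_gradient_checkpointing()  # sets the flag on the UNet blocks
unet.train()                          # checkpointing only fires in train mode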