summary refs log tree commit diff stats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-24 14:30:15 +0100
committerVolpeon <git@volpeon.ink>2022-12-24 14:30:15 +0100
commitd2105d96fdd18da035d2ad412e3fb6f579d5571a (patch)
treef6b5ff7f817875bcb086e88e7b4a9eebd537adbe /train_ti.py
parentTraining update (diff)
downloadtextual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.tar.gz
textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.tar.bz2
textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.zip
Fixed Textual Inversion
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py10
1 file changed, 8 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index a12b889..5f37d54 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule, CSVDataItem 25from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.ti import patch_trainable_embeddings 27from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
29from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
30 30
31logger = get_logger(__name__) 31logger = get_logger(__name__)
@@ -515,10 +515,16 @@ def main():
515 515
516 vae.requires_grad_(False) 516 vae.requires_grad_(False)
517 unet.requires_grad_(False) 517 unet.requires_grad_(False)
518 text_encoder.requires_grad_(False)
519 518
520 patch_trainable_embeddings(text_encoder, placeholder_token_id) 519 patch_trainable_embeddings(text_encoder, placeholder_token_id)
521 520
521 freeze_params(itertools.chain(
522 text_encoder.text_model.encoder.parameters(),
523 text_encoder.text_model.final_layer_norm.parameters(),
524 text_encoder.text_model.embeddings.position_embedding.parameters(),
525 text_encoder.text_model.embeddings.token_embedding.parameters(),
526 ))
527
522 prompt_processor = PromptProcessor(tokenizer, text_encoder) 528 prompt_processor = PromptProcessor(tokenizer, text_encoder)
523 529
524 if args.scale_lr: 530 if args.scale_lr: