diff options
author | Volpeon <git@volpeon.ink> | 2022-12-24 14:30:15 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-24 14:30:15 +0100 |
commit | d2105d96fdd18da035d2ad412e3fb6f579d5571a (patch) | |
tree | f6b5ff7f817875bcb086e88e7b4a9eebd537adbe /train_ti.py | |
parent | Training update (diff) | |
download | textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.tar.gz textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.tar.bz2 textual-inversion-diff-d2105d96fdd18da035d2ad412e3fb6f579d5571a.zip |
Fixed Textual Inversion
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index a12b889..5f37d54 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import CSVDataModule, CSVDataItem |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.ti import patch_trainable_embeddings | 27 | from training.ti import patch_trainable_embeddings |
28 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params |
29 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
30 | 30 | ||
31 | logger = get_logger(__name__) | 31 | logger = get_logger(__name__) |
@@ -515,10 +515,16 @@ def main(): | |||
515 | 515 | ||
516 | vae.requires_grad_(False) | 516 | vae.requires_grad_(False) |
517 | unet.requires_grad_(False) | 517 | unet.requires_grad_(False) |
518 | text_encoder.requires_grad_(False) | ||
519 | 518 | ||
520 | patch_trainable_embeddings(text_encoder, placeholder_token_id) | 519 | patch_trainable_embeddings(text_encoder, placeholder_token_id) |
521 | 520 | ||
521 | freeze_params(itertools.chain( | ||
522 | text_encoder.text_model.encoder.parameters(), | ||
523 | text_encoder.text_model.final_layer_norm.parameters(), | ||
524 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
525 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
526 | )) | ||
527 | |||
522 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | 528 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
523 | 529 | ||
524 | if args.scale_lr: | 530 | if args.scale_lr: |