path: root/train_lora.py
author     Volpeon <git@volpeon.ink>  2022-12-28 18:08:36 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-28 18:08:36 +0100
commit     83725794618164210a12843381724252fdd82cc2 (patch)
tree       ec29ade9891fe08dd10b5033214fc09237c2cb86 /train_lora.py
parent     Improved learning rate finder (diff)
Integrated updates from diffusers
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  10
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index ffca304..9a42cae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -513,11 +513,9 @@ def main():
 
     print(f"Training added text embeddings")
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     index_fixed_tokens = torch.arange(len(tokenizer))
     index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
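
For context, the second hunk swaps the repository's freeze_params helper for PyTorch's built-in nn.Module.requires_grad_, which sets the flag recursively on all of a module's parameters. Below is a minimal, self-contained sketch of the equivalence; the freeze_params body and the placeholder_token_id value are assumptions for illustration, not the repository's actual code.

import itertools

import torch
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

def freeze_params(params):
    # Old approach: walk an iterable of parameter tensors and clear the
    # requires_grad flag on each one individually.
    for param in params:
        param.requires_grad = False

freeze_params(itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters(),
))

# New approach: nn.Module.requires_grad_(False) recurses over the module's
# parameters itself, so neither the helper nor itertools.chain is needed.
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

# The token embeddings stay trainable. This mask selects every token id
# except the added placeholder, so the fixed embedding rows can be restored
# after each optimizer step.
placeholder_token_id = [49408]  # hypothetical id of a newly added token
num_tokens = text_encoder.get_input_embeddings().num_embeddings  # stands in for len(tokenizer)
index_fixed_tokens = torch.arange(num_tokens)
index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]

Both forms leave the same set of parameters trainable; per the commit message, the requires_grad_ form simply follows the idiom used in the upstream diffusers training scripts.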