summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-28 18:08:36 +0100
committerVolpeon <git@volpeon.ink>2022-12-28 18:08:36 +0100
commit83725794618164210a12843381724252fdd82cc2 (patch)
treeec29ade9891fe08dd10b5033214fc09237c2cb86 /train_dreambooth.py
parentImproved learning rate finder (diff)
downloadtextual-inversion-diff-83725794618164210a12843381724252fdd82cc2.tar.gz
textual-inversion-diff-83725794618164210a12843381724252fdd82cc2.tar.bz2
textual-inversion-diff-83725794618164210a12843381724252fdd82cc2.zip
Integrated updates from diffusers
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py16
1 files changed, 6 insertions, 10 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 325fe90..202d52c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.ti import patch_trainable_embeddings 27from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
30 30
31logger = get_logger(__name__) 31logger = get_logger(__name__)
@@ -580,12 +580,10 @@ def main():
580 580
581 patch_trainable_embeddings(text_encoder, placeholder_token_id) 581 patch_trainable_embeddings(text_encoder, placeholder_token_id)
582 582
583 freeze_params(itertools.chain( 583 text_encoder.text_model.encoder.requires_grad_(False)
584 text_encoder.text_model.encoder.parameters(), 584 text_encoder.text_model.final_layer_norm.requires_grad_(False)
585 text_encoder.text_model.final_layer_norm.parameters(), 585 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
586 text_encoder.text_model.embeddings.position_embedding.parameters(), 586 text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
587 text_encoder.text_model.embeddings.token_embedding.parameters(),
588 ))
589 587
590 prompt_processor = PromptProcessor(tokenizer, text_encoder) 588 prompt_processor = PromptProcessor(tokenizer, text_encoder)
591 589
@@ -905,9 +903,7 @@ def main():
905 if epoch < args.train_text_encoder_epochs: 903 if epoch < args.train_text_encoder_epochs:
906 text_encoder.train() 904 text_encoder.train()
907 elif epoch == args.train_text_encoder_epochs: 905 elif epoch == args.train_text_encoder_epochs:
908 freeze_params(text_encoder.parameters()) 906 text_encoder.requires_grad_(False)
909
910 sample_checkpoint = False
911 907
912 for step, batch in enumerate(train_dataloader): 908 for step, batch in enumerate(train_dataloader):
913 with accelerator.accumulate(unet): 909 with accelerator.accumulate(unet):