summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-01 20:48:04 +0100
committerVolpeon <git@volpeon.ink>2023-01-01 20:48:04 +0100
commiteb0838bd2bf96d34dd779f847552291379fe543f (patch)
tree501c41a8330a06ee0b0939a47ae74c281129ab47 /train_dreambooth.py
parentFix MultiCLIPTokenizer (forgot to override encode) (diff)
downloadtextual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.tar.gz
textual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.tar.bz2
textual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.zip
Cleanup
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  7
1 file changed, 4 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index b07de31..92f9b96 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -17,7 +17,7 @@ from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_
17import matplotlib.pyplot as plt 17import matplotlib.pyplot as plt
18from diffusers.training_utils import EMAModel 18from diffusers.training_utils import EMAModel
19from tqdm.auto import tqdm 19from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel
21from slugify import slugify 21from slugify import slugify
22 22
23from common import load_config, load_embeddings_from_dir 23from common import load_config, load_embeddings_from_dir
@@ -26,7 +26,7 @@ from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.lr import LRFinder 27from training.lr import LRFinder
28from training.util import AverageMeter, CheckpointerBase, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings 29from models.clip.embeddings import patch_managed_embeddings
30from models.clip.prompt import PromptProcessor 30from models.clip.prompt import PromptProcessor
31from models.clip.tokenizer import MultiCLIPTokenizer 31from models.clip.tokenizer import MultiCLIPTokenizer
32 32
@@ -617,7 +617,8 @@ def main():
617 if args.train_text_encoder: 617 if args.train_text_encoder:
618 print(f"Training entire text encoder.") 618 print(f"Training entire text encoder.")
619 619
620 unpatch_managed_embeddings(text_encoder) 620 embeddings.make_permanent()
621 text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
621 else: 622 else:
622 print(f"Training added text embeddings") 623 print(f"Training added text embeddings")
623 624