From eb0838bd2bf96d34dd779f847552291379fe543f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 1 Jan 2023 20:48:04 +0100
Subject: Cleanup

---
 train_dreambooth.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

(limited to 'train_dreambooth.py')

diff --git a/train_dreambooth.py b/train_dreambooth.py
index b07de31..92f9b96 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -17,7 +17,7 @@ from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel
 from slugify import slugify
 
 from common import load_config, load_embeddings_from_dir
@@ -26,7 +26,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
-from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer
 
@@ -617,7 +617,8 @@ def main():
 
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
-        unpatch_managed_embeddings(text_encoder)
+        embeddings.make_permanent()
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
     else:
         print(f"Training added text embeddings")
-- 
cgit v1.2.3-54-g00ecf