From eb0838bd2bf96d34dd779f847552291379fe543f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 1 Jan 2023 20:48:04 +0100
Subject: Cleanup

---
 models/clip/embeddings.py | 11 -----------
 models/clip/tokenizer.py  |  1 +
 train_dreambooth.py       |  7 ++++---
 3 files changed, 5 insertions(+), 14 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8602142..f90e7c2 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,14 +120,3 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe
     text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
-
-
-def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
-    text_encoder.text_model.embeddings.make_permanent()
-
-    text_embeddings = CLIPTextEmbeddings(text_encoder.config)
-    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
-    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
-    text_encoder.text_model.embeddings = text_embeddings
-
-    return text_embeddings
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 5e33f3e..bd0bd21 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -57,6 +57,7 @@ class MultiCLIPTokenizerItem(NamedTuple):
 class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.token_map: dict[int, list[int]] = {}
         self.vector_shuffle = shuffle_none
diff --git a/train_dreambooth.py b/train_dreambooth.py
index b07de31..92f9b96 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -17,7 +17,7 @@ from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel
 from slugify import slugify

 from common import load_config, load_embeddings_from_dir
@@ -26,7 +26,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
-from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer
@@ -617,7 +617,8 @@ def main():
     if args.train_text_encoder:
         print(f"Training entire text encoder.")
-        unpatch_managed_embeddings(text_encoder)
+        embeddings.make_permanent()
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
     else:
         print(f"Training added text embeddings")
-- 
cgit v1.2.3-54-g00ecf
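
For context, a minimal sketch (not part of the commit) of the pattern the new train_dreambooth.py lines apply: instead of unpatching the managed embeddings before full text-encoder training, the temporary token vectors are made permanent and the temp table is frozen in place. All names follow the diff above; train_text_encoder stands in for args.train_text_encoder, and the ManagedCLIPTextEmbeddings interface (make_permanent, temp_token_embedding) is assumed from the calls shown in the patch.

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    # Load a CLIP text encoder and swap in the managed embeddings layer.
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeddings = patch_managed_embeddings(text_encoder)

    train_text_encoder = True  # stand-in for args.train_text_encoder

    if train_text_encoder:
        # Fold the temporary token vectors into the permanent embedding
        # table, then freeze the temp table so gradients flow only to the
        # regular text-encoder weights during full fine-tuning.
        embeddings.make_permanent()
        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)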