summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--models/clip/embeddings.py11
-rw-r--r--models/clip/tokenizer.py1
-rw-r--r--train_dreambooth.py7
3 files changed, 5 insertions, 14 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8602142..f90e7c2 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -120,14 +120,3 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe
120 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) 120 text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
121 text_encoder.text_model.embeddings = text_embeddings 121 text_encoder.text_model.embeddings = text_embeddings
122 return text_embeddings 122 return text_embeddings
123
124
125def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings:
126 text_encoder.text_model.embeddings.make_permanent()
127
128 text_embeddings = CLIPTextEmbeddings(text_encoder.config)
129 text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
130 text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
131 text_encoder.text_model.embeddings = text_embeddings
132
133 return text_embeddings
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 5e33f3e..bd0bd21 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -57,6 +57,7 @@ class MultiCLIPTokenizerItem(NamedTuple):
57class MultiCLIPTokenizer(CLIPTokenizer): 57class MultiCLIPTokenizer(CLIPTokenizer):
58 def __init__(self, *args, **kwargs): 58 def __init__(self, *args, **kwargs):
59 super().__init__(*args, **kwargs) 59 super().__init__(*args, **kwargs)
60
60 self.token_map: dict[int, list[int]] = {} 61 self.token_map: dict[int, list[int]] = {}
61 self.vector_shuffle = shuffle_none 62 self.vector_shuffle = shuffle_none
62 63
diff --git a/train_dreambooth.py b/train_dreambooth.py
index b07de31..92f9b96 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -17,7 +17,7 @@ from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_
17import matplotlib.pyplot as plt 17import matplotlib.pyplot as plt
18from diffusers.training_utils import EMAModel 18from diffusers.training_utils import EMAModel
19from tqdm.auto import tqdm 19from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel
21from slugify import slugify 21from slugify import slugify
22 22
23from common import load_config, load_embeddings_from_dir 23from common import load_config, load_embeddings_from_dir
@@ -26,7 +26,7 @@ from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.lr import LRFinder 27from training.lr import LRFinder
28from training.util import AverageMeter, CheckpointerBase, save_args 28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings 29from models.clip.embeddings import patch_managed_embeddings
30from models.clip.prompt import PromptProcessor 30from models.clip.prompt import PromptProcessor
31from models.clip.tokenizer import MultiCLIPTokenizer 31from models.clip.tokenizer import MultiCLIPTokenizer
32 32
@@ -617,7 +617,8 @@ def main():
617 if args.train_text_encoder: 617 if args.train_text_encoder:
618 print(f"Training entire text encoder.") 618 print(f"Training entire text encoder.")
619 619
620 unpatch_managed_embeddings(text_encoder) 620 embeddings.make_permanent()
621 text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
621 else: 622 else:
622 print(f"Training added text embeddings") 623 print(f"Training added text embeddings")
623 624