diff options
| -rw-r--r-- | models/clip/embeddings.py | 11 | ||||
| -rw-r--r-- | models/clip/tokenizer.py | 1 | ||||
| -rw-r--r-- | train_dreambooth.py | 7 |
3 files changed, 5 insertions, 14 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8602142..f90e7c2 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
| @@ -120,14 +120,3 @@ def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbe | |||
| 120 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) | 120 | text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings) |
| 121 | text_encoder.text_model.embeddings = text_embeddings | 121 | text_encoder.text_model.embeddings = text_embeddings |
| 122 | return text_embeddings | 122 | return text_embeddings |
| 123 | |||
| 124 | |||
| 125 | def unpatch_managed_embeddings(text_encoder: CLIPTextModel) -> CLIPTextEmbeddings: | ||
| 126 | text_encoder.text_model.embeddings.make_permanent() | ||
| 127 | |||
| 128 | text_embeddings = CLIPTextEmbeddings(text_encoder.config) | ||
| 129 | text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding | ||
| 130 | text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding | ||
| 131 | text_encoder.text_model.embeddings = text_embeddings | ||
| 132 | |||
| 133 | return text_embeddings | ||
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 5e33f3e..bd0bd21 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
| @@ -57,6 +57,7 @@ class MultiCLIPTokenizerItem(NamedTuple): | |||
| 57 | class MultiCLIPTokenizer(CLIPTokenizer): | 57 | class MultiCLIPTokenizer(CLIPTokenizer): |
| 58 | def __init__(self, *args, **kwargs): | 58 | def __init__(self, *args, **kwargs): |
| 59 | super().__init__(*args, **kwargs) | 59 | super().__init__(*args, **kwargs) |
| 60 | |||
| 60 | self.token_map: dict[int, list[int]] = {} | 61 | self.token_map: dict[int, list[int]] = {} |
| 61 | self.vector_shuffle = shuffle_none | 62 | self.vector_shuffle = shuffle_none |
| 62 | 63 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py
index b07de31..92f9b96 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
| @@ -17,7 +17,7 @@ from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_ | |||
| 17 | import matplotlib.pyplot as plt | 17 | import matplotlib.pyplot as plt |
| 18 | from diffusers.training_utils import EMAModel | 18 | from diffusers.training_utils import EMAModel |
| 19 | from tqdm.auto import tqdm | 19 | from tqdm.auto import tqdm |
| 20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel |
| 21 | from slugify import slugify | 21 | from slugify import slugify |
| 22 | 22 | ||
| 23 | from common import load_config, load_embeddings_from_dir | 23 | from common import load_config, load_embeddings_from_dir |
| @@ -26,7 +26,7 @@ from data.csv import CSVDataModule, CSVDataItem | |||
| 26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
| 27 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
| 28 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
| 29 | from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
| 30 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
| 31 | from models.clip.tokenizer import MultiCLIPTokenizer | 31 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 32 | 32 | ||
| @@ -617,7 +617,8 @@ def main(): | |||
| 617 | if args.train_text_encoder: | 617 | if args.train_text_encoder: |
| 618 | print(f"Training entire text encoder.") | 618 | print(f"Training entire text encoder.") |
| 619 | 619 | ||
| 620 | unpatch_managed_embeddings(text_encoder) | 620 | embeddings.make_permanent() |
| 621 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) | ||
| 621 | else: | 622 | else: |
| 622 | print(f"Training added text embeddings") | 623 | print(f"Training added text embeddings") |
| 623 | 624 | ||
