diff options
author | Volpeon <git@volpeon.ink> | 2023-01-01 20:48:04 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-01 20:48:04 +0100 |
commit | eb0838bd2bf96d34dd779f847552291379fe543f (patch) | |
tree | 501c41a8330a06ee0b0939a47ae74c281129ab47 /train_dreambooth.py | |
parent | Fix MultiCLIPTokenizer (forgot to override encode) (diff) | |
download | textual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.tar.gz textual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.tar.bz2 textual-inversion-diff-eb0838bd2bf96d34dd779f847552291379fe543f.zip |
Cleanup
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index b07de31..92f9b96 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -17,7 +17,7 @@ from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_ | |||
17 | import matplotlib.pyplot as plt | 17 | import matplotlib.pyplot as plt |
18 | from diffusers.training_utils import EMAModel | 18 | from diffusers.training_utils import EMAModel |
19 | from tqdm.auto import tqdm | 19 | from tqdm.auto import tqdm |
20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel |
21 | from slugify import slugify | 21 | from slugify import slugify |
22 | 22 | ||
23 | from common import load_config, load_embeddings_from_dir | 23 | from common import load_config, load_embeddings_from_dir |
@@ -26,7 +26,7 @@ from data.csv import CSVDataModule, CSVDataItem | |||
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
28 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
29 | from models.clip.embeddings import patch_managed_embeddings, unpatch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
30 | from models.clip.prompt import PromptProcessor | 30 | from models.clip.prompt import PromptProcessor |
31 | from models.clip.tokenizer import MultiCLIPTokenizer | 31 | from models.clip.tokenizer import MultiCLIPTokenizer |
32 | 32 | ||
@@ -617,7 +617,8 @@ def main(): | |||
617 | if args.train_text_encoder: | 617 | if args.train_text_encoder: |
618 | print(f"Training entire text encoder.") | 618 | print(f"Training entire text encoder.") |
619 | 619 | ||
620 | unpatch_managed_embeddings(text_encoder) | 620 | embeddings.make_permanent() |
621 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) | ||
621 | else: | 622 | else: |
622 | print(f"Training added text embeddings") | 623 | print(f"Training added text embeddings") |
623 | 624 | ||