diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 22:42:44 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 22:42:44 +0100 |
commit | f00877a13bce50b02cfc3790f2d18a325e9ff95b (patch) | |
tree | ebbda04024081e9c3c00400fae98124f3db2cc9c /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.gz textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.bz2 textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py index a4e2dde..78c1b5c 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -11,20 +11,16 @@ import torch.utils.checkpoint | |||
11 | from accelerate import Accelerator | 11 | from accelerate import Accelerator |
12 | from accelerate.logging import get_logger | 12 | from accelerate.logging import get_logger |
13 | from accelerate.utils import LoggerType, set_seed | 13 | from accelerate.utils import LoggerType, set_seed |
14 | from diffusers import AutoencoderKL, UNet2DConditionModel | ||
15 | import matplotlib.pyplot as plt | 14 | import matplotlib.pyplot as plt |
16 | from transformers import CLIPTextModel | ||
17 | from slugify import slugify | 15 | from slugify import slugify |
18 | 16 | ||
19 | from util import load_config, load_embeddings_from_dir | 17 | from util import load_config, load_embeddings_from_dir |
20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
21 | from data.csv import VlpnDataModule, VlpnDataItem | 18 | from data.csv import VlpnDataModule, VlpnDataItem |
22 | from trainer.base import Checkpointer | 19 | from trainer_old.base import Checkpointer |
23 | from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models | 20 | from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models |
24 | from training.optimization import get_scheduler | 21 | from training.optimization import get_scheduler |
25 | from training.lr import LRFinder | 22 | from training.lr import LRFinder |
26 | from training.util import EMAModel, save_args | 23 | from training.util import EMAModel, save_args |
27 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
28 | 24 | ||
29 | logger = get_logger(__name__) | 25 | logger = get_logger(__name__) |
30 | 26 | ||
@@ -485,12 +481,16 @@ class TextualInversionCheckpointer(Checkpointer): | |||
485 | def __init__( | 481 | def __init__( |
486 | self, | 482 | self, |
487 | ema_embeddings: EMAModel, | 483 | ema_embeddings: EMAModel, |
484 | placeholder_tokens: list[str], | ||
485 | placeholder_token_ids: list[list[int]], | ||
488 | *args, | 486 | *args, |
489 | **kwargs, | 487 | **kwargs, |
490 | ): | 488 | ): |
491 | super().__init__(*args, **kwargs) | 489 | super().__init__(*args, **kwargs) |
492 | 490 | ||
493 | self.ema_embeddings = ema_embeddings | 491 | self.ema_embeddings = ema_embeddings |
492 | self.placeholder_tokens = placeholder_tokens | ||
493 | self.placeholder_token_ids = placeholder_token_ids | ||
494 | 494 | ||
495 | @torch.no_grad() | 495 | @torch.no_grad() |
496 | def checkpoint(self, step, postfix): | 496 | def checkpoint(self, step, postfix): |