| author | Volpeon <git@volpeon.ink> | 2023-01-14 22:42:44 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 22:42:44 +0100 |
| commit | f00877a13bce50b02cfc3790f2d18a325e9ff95b (patch) | |
| tree | ebbda04024081e9c3c00400fae98124f3db2cc9c /train_ti.py | |
| parent | Update (diff) | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 10 |
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index a4e2dde..78c1b5c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -11,20 +11,16 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel
 import matplotlib.pyplot as plt
-from transformers import CLIPTextModel
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from trainer.base import Checkpointer
+from trainer_old.base import Checkpointer
 from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
-from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -485,12 +481,16 @@ class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
         ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
         self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
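The second hunk extends `TextualInversionCheckpointer` so the checkpointer also carries the placeholder token strings and their resolved token ids. Below is a minimal usage sketch of the updated constructor; only the three named parameters are confirmed by the new `__init__` signature, while the example values and the remaining base-class arguments are assumptions for illustration, taken from the surrounding `train_ti.py` context rather than this diff.

```python
# Hypothetical sketch, not taken from the diff: only ema_embeddings,
# placeholder_tokens, and placeholder_token_ids are confirmed by the new
# __init__ signature; the values and any extra arguments are illustrative.
placeholder_tokens = ["<concept>"]        # placeholder strings added to the tokenizer
placeholder_token_ids = [[49408, 49409]]  # one list of ids per placeholder (multi-vector tokens)

checkpointer = TextualInversionCheckpointer(
    ema_embeddings=ema_embeddings,               # EMAModel tracking the learned embeddings
    placeholder_tokens=placeholder_tokens,
    placeholder_token_ids=placeholder_token_ids,
    # ...any further arguments are forwarded to the base Checkpointer via *args/**kwargs
)
```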
