summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-14 22:42:44 +0100
committerVolpeon <git@volpeon.ink>2023-01-14 22:42:44 +0100
commitf00877a13bce50b02cfc3790f2d18a325e9ff95b (patch)
treeebbda04024081e9c3c00400fae98124f3db2cc9c /train_ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.gz
textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.tar.bz2
textual-inversion-diff-f00877a13bce50b02cfc3790f2d18a325e9ff95b.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/train_ti.py b/train_ti.py
index a4e2dde..78c1b5c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -11,20 +11,16 @@ import torch.utils.checkpoint
11from accelerate import Accelerator 11from accelerate import Accelerator
12from accelerate.logging import get_logger 12from accelerate.logging import get_logger
13from accelerate.utils import LoggerType, set_seed 13from accelerate.utils import LoggerType, set_seed
14from diffusers import AutoencoderKL, UNet2DConditionModel
15import matplotlib.pyplot as plt 14import matplotlib.pyplot as plt
16from transformers import CLIPTextModel
17from slugify import slugify 15from slugify import slugify
18 16
19from util import load_config, load_embeddings_from_dir 17from util import load_config, load_embeddings_from_dir
20from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
21from data.csv import VlpnDataModule, VlpnDataItem 18from data.csv import VlpnDataModule, VlpnDataItem
22from trainer.base import Checkpointer 19from trainer_old.base import Checkpointer
23from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models 20from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
24from training.optimization import get_scheduler 21from training.optimization import get_scheduler
25from training.lr import LRFinder 22from training.lr import LRFinder
26from training.util import EMAModel, save_args 23from training.util import EMAModel, save_args
27from models.clip.tokenizer import MultiCLIPTokenizer
28 24
29logger = get_logger(__name__) 25logger = get_logger(__name__)
30 26
@@ -485,12 +481,16 @@ class TextualInversionCheckpointer(Checkpointer):
485 def __init__( 481 def __init__(
486 self, 482 self,
487 ema_embeddings: EMAModel, 483 ema_embeddings: EMAModel,
484 placeholder_tokens: list[str],
485 placeholder_token_ids: list[list[int]],
488 *args, 486 *args,
489 **kwargs, 487 **kwargs,
490 ): 488 ):
491 super().__init__(*args, **kwargs) 489 super().__init__(*args, **kwargs)
492 490
493 self.ema_embeddings = ema_embeddings 491 self.ema_embeddings = ema_embeddings
492 self.placeholder_tokens = placeholder_tokens
493 self.placeholder_token_ids = placeholder_token_ids
494 494
495 @torch.no_grad() 495 @torch.no_grad()
496 def checkpoint(self, step, postfix): 496 def checkpoint(self, step, postfix):