From f00877a13bce50b02cfc3790f2d18a325e9ff95b Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 22:42:44 +0100
Subject: Update

---
 train_ti.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index a4e2dde..78c1b5c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -11,20 +11,16 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, UNet2DConditionModel
 import matplotlib.pyplot as plt
-from transformers import CLIPTextModel
 from slugify import slugify

 from util import load_config, load_embeddings_from_dir
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from trainer.base import Checkpointer
+from trainer_old.base import Checkpointer
 from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
-from models.clip.tokenizer import MultiCLIPTokenizer

 logger = get_logger(__name__)

@@ -485,12 +481,16 @@ class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
         ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)

         self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids

     @torch.no_grad()
     def checkpoint(self, step, postfix):
--
cgit v1.2.3-54-g00ecf
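
For context, the two new constructor arguments give the checkpointer everything it needs to persist each placeholder token's learned vectors under that token's name. Below is a minimal, self-contained sketch of that saving pattern; the function name save_placeholder_embeddings, the directory layout, and the use of torch.save are illustrative assumptions, not this repository's actual API.

    # Illustrative sketch only: shows how placeholder_tokens and
    # placeholder_token_ids can be paired to persist per-token embeddings.
    # Function name, file layout, and torch.save format are assumptions.
    from pathlib import Path

    import torch
    from slugify import slugify  # also imported by train_ti.py


    def save_placeholder_embeddings(
        embedding_weight: torch.Tensor,          # token-embedding matrix, shape (vocab, dim)
        placeholder_tokens: list[str],           # e.g. ["<concept>"]
        placeholder_token_ids: list[list[int]],  # ids per token; one token may own several
        save_dir: Path,
    ) -> None:
        save_dir.mkdir(parents=True, exist_ok=True)
        for token, ids in zip(placeholder_tokens, placeholder_token_ids):
            # Gather this token's learned vectors from the embedding matrix.
            vectors = embedding_weight[torch.tensor(ids)].detach().cpu()
            torch.save({token: vectors}, save_dir / f"{slugify(token)}.bin")


    # Usage with dummy data:
    weights = torch.randn(49410, 768)
    save_placeholder_embeddings(weights, ["<concept>"], [[49408, 49409]], Path("checkpoints"))

Keeping the ids grouped per token (list[list[int]] rather than a flat list) matters because a single placeholder can map to several embedding vectors, as in this repo's multi-vector tokenizer; the grouping preserves which vectors belong to which token.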