From 83808fe00ac891ad2f625388d144c318b2cb5bfe Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 21:53:07 +0100
Subject: WIP: Modularization ("free(): invalid pointer" my ass)

---
 train_ti.py | 74 +++++++++++++------------------------------------------------
 1 file changed, 15 insertions(+), 59 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 8631892..deed84c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -19,10 +19,11 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from trainer.base import Checkpointer
+from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import CheckpointerBase, EMAModel, save_args
+from training.util import EMAModel, save_args
 from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
@@ -480,38 +481,20 @@ def parse_args():
     return args
 
 
-class Checkpointer(CheckpointerBase):
+class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
-        weight_dtype: torch.dtype,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
         ema_embeddings: EMAModel,
-        scheduler,
-        placeholder_tokens,
-        placeholder_token_ids,
         *args,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
         self.ema_embeddings = ema_embeddings
-        self.scheduler = scheduler
-        self.placeholder_tokens = placeholder_tokens
-        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
-        print("Saving checkpoint for step %d..." % step)
+        print(f"Saving checkpoint for step {step}...")
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
@@ -519,7 +502,8 @@ class Checkpointer(CheckpointerBase):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
             for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
@@ -528,42 +512,14 @@ class Checkpointer(CheckpointerBase):
                 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
             )
 
-        del text_encoder
-
-    @torch.no_grad()
+    @torch.inference_mode()
     def save_samples(self, step):
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
         ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
-            orig_unet_dtype = unet.dtype
-            orig_text_encoder_dtype = text_encoder.dtype
-
-            unet.to(dtype=self.weight_dtype)
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=self.unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step)
-
-            unet.to(dtype=orig_unet_dtype)
-            text_encoder.to(dtype=orig_text_encoder_dtype)
-
-            del text_encoder
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            super().save_samples(step)
 
 
 def main():
@@ -806,8 +762,8 @@ def main():
         args.seed,
     )
 
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
+    checkpointer = TextualInversionCheckpointer(
+        dtype=weight_dtype,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         accelerator=accelerator,
@@ -816,7 +772,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         ema_embeddings=ema_embeddings,
-        scheduler=sample_scheduler,
+        sample_scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         output_dir=output_dir,
--
cgit v1.2.3-54-g00ecf
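
For reference, below is a minimal sketch of the refactoring pattern this commit applies: the per-script checkpointer keeps only textual-inversion-specific state (the EMA embeddings) and forwards everything else to a shared base class. Note that trainer.base.Checkpointer is not part of this diff, so the base class here is a hypothetical stand-in whose fields are inferred from the attributes the subclass stopped storing and from the keyword arguments passed in main(); only TextualInversionCheckpointer mirrors the patch.

# refactoring_sketch.py -- illustrative only; the base class is assumed.
from contextlib import nullcontext

import torch


class Checkpointer:
    """Hypothetical stand-in for trainer.base.Checkpointer (not in this diff)."""

    def __init__(self, accelerator=None, vae=None, unet=None, tokenizer=None,
                 text_encoder=None, dtype=None, sample_scheduler=None,
                 placeholder_tokens=None, placeholder_token_ids=None,
                 output_dir=None, **kwargs):
        # Shared state that was previously duplicated in every training
        # script's Checkpointer subclass now lives in one place.
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.dtype = dtype
        self.sample_scheduler = sample_scheduler
        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids
        self.output_dir = output_dir

    def save_samples(self, step):
        # The real base implementation would build the sampling pipeline
        # (VlpnStableDiffusion) and render sample images; stubbed here.
        pass


class TextualInversionCheckpointer(Checkpointer):
    # Mirrors the subclass in the patch: only method-specific state stays
    # here; everything else is forwarded to the base class.
    def __init__(self, ema_embeddings=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ema_embeddings = ema_embeddings

    @torch.inference_mode()
    def save_samples(self, step):
        # Temporarily swap in the EMA-averaged embedding weights (when EMA
        # is enabled) before delegating sampling to the base class.
        ema_context = self.ema_embeddings.apply_temporary(
            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            super().save_samples(step)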