From 5b9a3de142e7a645573b4f4a8c1ce9c59746ab08 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 09:25:30 +0100
Subject: Added functional trainer

---
 train_ti.py | 49 +++++++++++++++++++++++--------------------------
 1 file changed, 23 insertions(+), 26 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 78c1b5c..97e4e72 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -703,17 +703,27 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
-    if args.gradient_checkpointing:
-        unet.train()
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        text_encoder=text_encoder,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+    )
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        if args.gradient_checkpointing:
+            unet.train()
 
     @contextmanager
     def on_train(epoch: int):
@@ -752,16 +762,6 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
     checkpointer = TextualInversionCheckpointer(
         dtype=weight_dtype,
         train_dataloader=train_dataloader,
@@ -803,18 +803,15 @@ def main():
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        train_loop(
-            accelerator=accelerator,
+        trainer(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            model=text_encoder,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
+            num_train_epochs=args.num_train_epochs,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            num_epochs=args.num_train_epochs,
+            prior_loss_weight=args.prior_loss_weight,
+            on_prepare=on_prepare,
             on_log=on_log,
             on_train=on_train,
             on_after_optimize=on_after_optimize,
--
cgit v1.2.3-54-g00ecf