From 632ce00b54ffeacfc18f44f10827f167ab3ac37c Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 21:06:16 +0100
Subject: Restored functional trainer

---
 train_ti.py | 82 ++++++++++++++++---------------------------------------------
 1 file changed, 21 insertions(+), 61 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 4bac736..77dec12 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -10,15 +10,13 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-import matplotlib.pyplot as plt
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
 from training.util import save_args
 
 logger = get_logger(__name__)
@@ -644,23 +642,33 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
-    callbacks = textual_inversion_strategy(
+    trainer = partial(
+        train,
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
-        tokenizer=tokenizer,
         vae=vae,
-        sample_scheduler=sample_scheduler,
+        noise_scheduler=noise_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        output_dir=output_dir,
+        dtype=weight_dtype,
         seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    trainer(
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=output_dir,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         learning_rate=args.learning_rate,
@@ -679,54 +687,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.num_class_images != 0,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            callbacks=callbacks,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e3)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        if accelerator.is_main_process:
-            accelerator.init_trackers("textual_inversion")
-
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            callbacks=callbacks,
-        )
-
-        accelerator.end_training()
-
 
 if __name__ == "__main__":
     main()
--
cgit v1.2.3-54-g00ecf
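
Note on the restored call pattern: the new code is plain functools.partial currying. The invariant model/data plumbing is bound once, run-specific options are passed per invocation, and callbacks_fn lets train build its strategy callbacks internally instead of receiving a prebuilt callbacks object. The sketch below illustrates only that pattern; the train body and the strategy stub are assumptions for demonstration, not the repository's actual training.functional.train or textual_inversion_strategy.

from functools import partial


def train(accelerator, unet, text_encoder, callbacks_fn,
          optimizer=None, num_train_epochs=1, **strategy_kwargs):
    # Hypothetical stand-in for training.functional.train: the strategy
    # function receives the models plus any leftover keyword arguments
    # and returns the callbacks the shared training loop would use.
    callbacks = callbacks_fn(accelerator=accelerator, unet=unet,
                             text_encoder=text_encoder, **strategy_kwargs)
    print(f"{num_train_epochs} epoch(s), callbacks: {sorted(callbacks)}")


def textual_inversion_strategy(accelerator, unet, text_encoder, **kwargs):
    # Hypothetical stub: returns a mapping of callback hooks.
    return {"on_prepare": lambda: None, **{k: None for k in kwargs}}


# Bind the invariant plumbing once...
trainer = partial(train, accelerator="accel", unet="unet",
                  text_encoder="te", callbacks_fn=textual_inversion_strategy)
# ...then supply only the run-specific options per invocation.
trainer(optimizer="adam", num_train_epochs=10, placeholder_tokens=["<tok>"])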