From 59bf501198d7ff6c0c03c45e92adef14069d5ac6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 12:33:52 +0100
Subject: Update

---
 train_ti.py | 74 +++++++++++++++++++++++++++++++------------------------------
 1 file changed, 38 insertions(+), 36 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 3c9810f..4bac736 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,11 +15,11 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import EMAModel, save_args
+from training.util import save_args
 
 logger = get_logger(__name__)
 
@@ -82,7 +82,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=1,
+        default=0,
         help="How many class images to generate."
     )
     parser.add_argument(
@@ -398,7 +398,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=0,
+        default=1,
         type=float,
         help="Embedding decay factor."
     )
@@ -540,16 +540,6 @@ def main():
     placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
     print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
-    if args.use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-    else:
-        ema_embeddings = None
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -654,23 +644,13 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    if args.use_ema:
-        ema_embeddings.to(accelerator.device)
-
-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        text_encoder=text_encoder,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
-    strategy = textual_inversion_strategy(
+    vae.to(accelerator.device, dtype=weight_dtype)
+
+    callbacks = textual_inversion_strategy(
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
@@ -679,7 +659,6 @@ def main():
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=output_dir,
         seed=args.seed,
         placeholder_tokens=args.placeholder_tokens,
@@ -700,31 +679,54 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        args.num_class_images != 0,
+        args.prior_loss_weight,
+        args.seed,
+    )
+
     if args.find_lr:
         lr_finder = LRFinder(
             accelerator=accelerator,
             optimizer=optimizer,
-            model=text_encoder,
             train_dataloader=train_dataloader,
             val_dataloader=val_dataloader,
-            **strategy,
+            callbacks=callbacks,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        trainer(
+        if accelerator.is_main_process:
+            accelerator.init_trackers("textual_inversion")
+
+        train_loop(
+            accelerator=accelerator,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            num_train_epochs=args.num_train_epochs,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            prior_loss_weight=args.prior_loss_weight,
-            callbacks=strategy,
+            callbacks=callbacks,
        )
 
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     main()
-- 
cgit v1.2.3-54-g00ecf