From 34648b763fa60e3161a5b5f1243ed1b5c3b0188e Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 10:12:04 +0100
Subject: Added functional TI strategy

---
 train_ti.py | 108 +++++++++++++++++-------------------------------------
 1 file changed, 30 insertions(+), 78 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 97e4e72..2fd325b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,8 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -386,6 +387,11 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
     parser.add_argument(
         "--emb_decay_target",
         default=0.4,
@@ -591,14 +597,6 @@ def main():
     else:
         ema_embeddings = None
 
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -719,73 +717,36 @@ def main():
         seed=args.seed,
     )
 
-    def on_prepare():
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
-
-        if args.gradient_checkpointing:
-            unet.train()
-
-    @contextmanager
-    def on_train(epoch: int):
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
-
-    @contextmanager
-    def on_eval():
-        try:
-            tokenizer.eval()
-
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        if args.emb_decay_factor != 0:
-            text_encoder.text_model.embeddings.normalize(
-                args.emb_decay_target,
-                min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
-            )
-
-        if args.use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    def on_log():
-        if args.use_ema:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
-
-    checkpointer = TextualInversionCheckpointer(
-        dtype=weight_dtype,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
+    strategy = textual_inversion_strategy(
         accelerator=accelerator,
-        vae=vae,
         unet=unet,
-        tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
+        tokenizer=tokenizer,
+        vae=vae,
         sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        output_dir=output_dir,
+        seed=args.seed,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=output_dir,
-        sample_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
+        learning_rate=args.learning_rate,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay_factor=args.emb_decay_factor,
+        emb_decay_start=args.emb_decay_start,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )
 
-    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
-
     if args.find_lr:
         lr_finder = LRFinder(
             accelerator=accelerator,
@@ -793,10 +754,7 @@ def main():
             model=text_encoder,
             train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            on_train=on_train,
-            on_eval=on_eval,
-            on_after_optimize=on_after_optimize,
+            **strategy,
         )
 
         lr_finder.run(num_epochs=100, end_lr=1e3)
@@ -811,13 +769,7 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         global_step_offset=global_step_offset,
         prior_loss_weight=args.prior_loss_weight,
-        on_prepare=on_prepare,
-        on_log=on_log,
-        on_train=on_train,
-        on_after_optimize=on_after_optimize,
-        on_eval=on_eval,
-        on_sample=checkpointer.save_samples,
-        on_checkpoint=checkpointer.checkpoint,
+        **strategy,
     )
-- 
cgit v1.2.3-54-g00ecf
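
For context (not part of the patch): training/strategy/ti.py itself is not shown here, only its call site. Judging by the removed inline callbacks and the way the result is splatted into LRFinder(...) and train(...) via **strategy, the factory presumably bundles the same callbacks into a mapping. A minimal, hypothetical sketch of that pattern, reusing only names that appear in the diff:

# Sketch only -- NOT the real training/strategy/ti.py. It illustrates a strategy
# factory returning a dict of callbacks that can be unpacked with **strategy;
# every internal detail beyond the names visible in the diff is an assumption.
from contextlib import contextmanager, nullcontext


def textual_inversion_strategy(*, tokenizer, text_encoder, ema_embeddings=None,
                               use_ema: bool = False, **kwargs):
    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()   # custom tokenizer: enable training-mode prompt handling
        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()    # deterministic prompts during validation
        # temporarily swap in EMA embedding weights when EMA is enabled
        ema_context = (
            ema_embeddings.apply_temporary(
                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
            if use_ema else nullcontext()
        )
        with ema_context:
            yield

    # ...on_prepare, on_after_optimize, on_log, on_sample and on_checkpoint would
    # be built the same way from the remaining keyword arguments...

    return {
        "on_train": on_train,
        "on_eval": on_eval,
    }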