From 5821523a524190490a287c5e2aacb6e72cc3a4cf Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 17 Jan 2023 07:20:45 +0100
Subject: Update

---
 train_ti.py | 113 ++++++++++++++++++++++++++++++++++--------------------------
 1 file changed, 64 insertions(+), 49 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index e7aeb23..0891c49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -14,7 +14,7 @@ from slugify import slugify

 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -78,6 +78,10 @@ def parse_args():
         nargs='*',
         help="Number of vectors per embedding."
     )
+    parser.add_argument(
+        "--simultaneous",
+        action="store_true",
+    )
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -474,11 +478,12 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

-    if isinstance(args.train_data_template, str):
-        args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
+    if not args.simultaneous:
+        if isinstance(args.train_data_template, str):
+            args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)

-    if len(args.placeholder_tokens) != len(args.train_data_template):
-        raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+        if len(args.placeholder_tokens) != len(args.train_data_template):
+            raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")

     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -560,6 +565,8 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

+    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -569,30 +576,50 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
-        callbacks_fn=textual_inversion_strategy
-    )
-
-    checkpoint_output_dir = output_dir.joinpath("checkpoints")
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.train_data_template
-    ):
-        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        strategy=textual_inversion_strategy,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        checkpoint_output_dir=checkpoint_output_dir,
+        learning_rate=args.learning_rate,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay_factor=args.emb_decay_factor,
+        emb_decay_start=args.emb_decay_start,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )
+
+    def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+        if len(placeholder_tokens) == 1:
+            sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}")
+        else:
+            sample_output_dir = output_dir.joinpath("samples")

         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
+            placeholder_tokens=placeholder_tokens,
+            initializer_tokens=initializer_tokens,
+            num_vectors=num_vectors
         )

-        print(
-            f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
+        stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
+
+        print(f"{i + 1}: {stats})")

         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
@@ -612,7 +639,7 @@ def main():
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
             seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
+            filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype
         )
         datamodule.setup()
@@ -647,36 +674,24 @@ def main():
             val_dataloader=datamodule.val_dataloader,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            num_train_epochs=args.num_train_epochs,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            with_prior_preservation=args.num_class_images != 0,
-            prior_loss_weight=args.prior_loss_weight,
             # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
             sample_output_dir=sample_output_dir,
-            checkpoint_output_dir=checkpoint_output_dir,
-            placeholder_tokens=[placeholder_token],
+            placeholder_tokens=placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=args.use_emb_decay,
-            emb_decay_target=args.emb_decay_target,
-            emb_decay_factor=args.emb_decay_factor,
-            emb_decay_start=args.emb_decay_start,
-            use_ema=args.use_ema,
-            ema_inv_gamma=args.ema_inv_gamma,
-            ema_power=args.ema_power,
-            ema_max_decay=args.ema_max_decay,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
         )

-        embeddings.persist()
+    if args.simultaneous:
+        run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
+    else:
+        for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
+            range(len(args.placeholder_tokens)),
+            args.placeholder_tokens,
+            args.initializer_tokens,
+            args.num_vectors,
+            args.train_data_template
+        ):
+            run(i, [placeholder_token], [initializer_token], [num_vectors], data_template)
+            embeddings.persist()


 if __name__ == "__main__":
--
cgit v1.2.3-54-g00ecf
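
For orientation, the new --simultaneous flag selects between a single combined run over all placeholder tokens and the previous one-run-per-token loop, with the static trainer arguments now bound once via functools.partial. The sketch below is a simplification for illustration only; dispatch, run, args, and embeddings are stand-ins for the objects defined in train_ti.py, not the actual code.

# Illustration only: simplified dispatch mirroring the patched main() in train_ti.py.
def dispatch(args, run, embeddings):
    if args.simultaneous:
        # One training run covering every placeholder token at once;
        # samples are written to a shared "samples" directory.
        run(0, args.placeholder_tokens, args.initializer_tokens,
            args.num_vectors, args.train_data_template)
    else:
        # Previous behavior: one run per token, persisting embeddings after
        # each run so later runs build on the earlier results.
        for i, (token, init, vectors, template) in enumerate(
            zip(args.placeholder_tokens, args.initializer_tokens,
                args.num_vectors, args.train_data_template)
        ):
            run(i, [token], [init], [vectors], template)
            embeddings.persist()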