From 36440e48ce279872d6e736bcb1bf57d13da73a11 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 17:09:01 +0100 Subject: Moved multi-TI code from Dreambooth to TI script --- train_ti.py | 221 +++++++++++++++++++++++++++++++----------------------------- 1 file changed, 114 insertions(+), 107 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 7aecdef..adba8d4 100644 --- a/train_ti.py +++ b/train_ti.py @@ -51,6 +51,7 @@ def parse_args(): parser.add_argument( "--train_data_template", type=str, + nargs='*', default="template", ) parser.add_argument( @@ -468,11 +469,17 @@ def parse_args(): args.num_vectors = 1 if isinstance(args.num_vectors, int): - args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) + args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) if len(args.placeholder_tokens) != len(args.num_vectors): raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") + if isinstance(args.train_data_template, str): + args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) + + if len(args.placeholder_tokens) != len(args.train_data_template): + raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") + if isinstance(args.collection, str): args.collection = [args.collection] @@ -507,6 +514,8 @@ def main(): set_seed(args.seed) + seed_generator = torch.Generator().manual_seed(args.seed) + save_args(output_dir, args) tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( @@ -531,19 +540,6 @@ def main(): added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") - placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( - tokenizer=tokenizer, - embeddings=embeddings, - placeholder_tokens=args.placeholder_tokens, - initializer_tokens=args.initializer_tokens, - num_vectors=args.num_vectors - ) - - if len(placeholder_token_ids) != 0: - initializer_token_id_lens = [len(id) for id in initializer_token_ids] - placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) - print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") - if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -566,43 +562,6 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - datamodule = VlpnDataModule( - data_file=args.train_data_file, - batch_size=args.train_batch_size, - tokenizer=tokenizer, - class_subdir=args.class_image_dir, - num_class_images=args.num_class_images, - size=args.resolution, - num_buckets=args.num_buckets, - progressive_buckets=args.progressive_buckets, - bucket_step_size=args.bucket_step_size, - bucket_max_pixels=args.bucket_max_pixels, - dropout=args.tag_dropout, - shuffle=not args.no_tag_shuffle, - template_key=args.train_data_template, - valid_set_size=args.valid_set_size, - train_set_pad=args.train_set_pad, - valid_set_pad=args.valid_set_pad, - seed=args.seed, - filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), - dtype=weight_dtype - ) - datamodule.setup() - - if args.num_class_images != 0: - generate_class_images( - accelerator, - text_encoder, - vae, - unet, - tokenizer, - sample_scheduler, - datamodule.train_dataset, - args.sample_batch_size, - args.sample_image_size, - args.sample_steps - ) - trainer = partial( train, accelerator=accelerator, @@ -615,63 +574,111 @@ def main(): callbacks_fn=textual_inversion_strategy ) - optimizer = optimizer_class( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - amsgrad=args.adam_amsgrad, - ) + for i, placeholder_token, initializer_token, num_vectors, data_template in zip( + range(len(args.placeholder_tokens)), + args.placeholder_tokens, + args.initializer_tokens, + args.num_vectors, + args.train_data_template + ): + cur_subdir = output_dir.joinpath(placeholder_token) + cur_subdir.mkdir(parents=True, exist_ok=True) + + placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=[placeholder_token], + initializer_tokens=[initializer_token], + num_vectors=[num_vectors] + ) - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_training_steps_per_epoch=len(datamodule.train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps, - min_lr=args.lr_min_lr, - warmup_func=args.lr_warmup_func, - annealing_func=args.lr_annealing_func, - warmup_exp=args.lr_warmup_exp, - annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, - train_epochs=args.num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, - ) - - trainer( - project="textual_inversion", - train_dataloader=datamodule.train_dataloader, - val_dataloader=datamodule.val_dataloader, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - num_train_epochs=args.num_train_epochs, - sample_frequency=args.sample_frequency, - checkpoint_frequency=args.checkpoint_frequency, - global_step_offset=global_step_offset, - with_prior_preservation=args.num_class_images != 0, - prior_loss_weight=args.prior_loss_weight, - # -- - tokenizer=tokenizer, - sample_scheduler=sample_scheduler, - output_dir=output_dir, - placeholder_tokens=args.placeholder_tokens, - placeholder_token_ids=placeholder_token_ids, - learning_rate=args.learning_rate, - gradient_checkpointing=args.gradient_checkpointing, - use_emb_decay=args.use_emb_decay, - emb_decay_target=args.emb_decay_target, - emb_decay_factor=args.emb_decay_factor, - emb_decay_start=args.emb_decay_start, - use_ema=args.use_ema, - ema_inv_gamma=args.ema_inv_gamma, - ema_power=args.ema_power, - ema_max_decay=args.ema_max_decay, - sample_batch_size=args.sample_batch_size, - sample_num_batches=args.sample_batches, - sample_num_steps=args.sample_steps, - sample_image_size=args.sample_image_size, - ) + print( + f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") + + args.seed = seed_generator.seed() + + datamodule = VlpnDataModule( + data_file=args.train_data_file, + batch_size=args.train_batch_size, + tokenizer=tokenizer, + class_subdir=args.class_image_dir, + num_class_images=args.num_class_images, + size=args.resolution, + num_buckets=args.num_buckets, + progressive_buckets=args.progressive_buckets, + bucket_step_size=args.bucket_step_size, + bucket_max_pixels=args.bucket_max_pixels, + dropout=args.tag_dropout, + shuffle=not args.no_tag_shuffle, + template_key=data_template, + valid_set_size=args.valid_set_size, + train_set_pad=args.train_set_pad, + valid_set_pad=args.valid_set_pad, + seed=args.seed, + filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), + dtype=weight_dtype + ) + datamodule.setup() + + optimizer = optimizer_class( + text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + amsgrad=args.adam_amsgrad, + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps_per_epoch=len(datamodule.train_dataloader), + gradient_accumulation_steps=args.gradient_accumulation_steps, + min_lr=args.lr_min_lr, + warmup_func=args.lr_warmup_func, + annealing_func=args.lr_annealing_func, + warmup_exp=args.lr_warmup_exp, + annealing_exp=args.lr_annealing_exp, + cycles=args.lr_cycles, + train_epochs=args.num_train_epochs, + warmup_epochs=args.lr_warmup_epochs, + ) + + trainer( + project="textual_inversion", + train_dataloader=datamodule.train_dataloader, + val_dataloader=datamodule.val_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + num_train_epochs=args.num_train_epochs, + sample_frequency=args.sample_frequency, + checkpoint_frequency=args.checkpoint_frequency, + global_step_offset=global_step_offset, + with_prior_preservation=args.num_class_images != 0, + prior_loss_weight=args.prior_loss_weight, + # -- + tokenizer=tokenizer, + sample_scheduler=sample_scheduler, + output_dir=cur_subdir, + placeholder_tokens=[placeholder_token], + placeholder_token_ids=placeholder_token_ids, + learning_rate=args.learning_rate, + gradient_checkpointing=args.gradient_checkpointing, + use_emb_decay=args.use_emb_decay, + emb_decay_target=args.emb_decay_target, + emb_decay_factor=args.emb_decay_factor, + emb_decay_start=args.emb_decay_start, + use_ema=args.use_ema, + ema_inv_gamma=args.ema_inv_gamma, + ema_power=args.ema_power, + ema_max_decay=args.ema_max_decay, + sample_batch_size=args.sample_batch_size, + sample_num_batches=args.sample_batches, + sample_num_steps=args.sample_steps, + sample_image_size=args.sample_image_size, + ) + + embeddings.persist() if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf