From 36440e48ce279872d6e736bcb1bf57d13da73a11 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 16 Jan 2023 17:09:01 +0100
Subject: Moved multi-TI code from Dreambooth to TI script

---
 train_dreambooth.py | 135 +---------------------------------------------------
 1 file changed, 2 insertions(+), 133 deletions(-)

(limited to 'train_dreambooth.py')

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1dc41b1..6511f9b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -199,23 +199,6 @@ def parse_args():
         type=int,
         default=100
     )
-    parser.add_argument(
-        "--ti_data_template",
-        type=str,
-        nargs='*',
-        default=[],
-    )
-    parser.add_argument(
-        "--ti_num_train_epochs",
-        type=int,
-        default=10
-    )
-    parser.add_argument(
-        "--ti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -244,12 +227,6 @@
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--ti_learning_rate",
-        type=float,
-        default=1e-2,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -482,12 +459,6 @@
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if isinstance(args.ti_data_template, str):
-        args.ti_data_template = [args.ti_data_template]
-
-    if len(args.ti_data_template) == 0:
-        raise ValueError("You must specify --ti_data_template")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -521,8 +492,6 @@ def main():
 
     set_seed(args.seed)
 
-    seed_generator = torch.Generator().manual_seed(args.seed)
-
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -583,107 +552,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
-    # Initial TI
-
-    print("Phase 1: Textual Inversion")
-
-    cur_dir = output_dir.joinpath("1-ti")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.ti_data_template
-    ):
-        cur_subdir = cur_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
-
-        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-            tokenizer=tokenizer,
-            embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
-        )
-
-        print(
-            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
-
-        args.seed = seed_generator.seed()
-
-        datamodule = VlpnDataModule(
-            data_file=args.train_data_file,
-            batch_size=args.ti_batch_size,
-            tokenizer=tokenizer,
-            class_subdir=args.class_image_dir,
-            num_class_images=args.num_class_images,
-            size=args.resolution,
-            shuffle=not args.no_tag_shuffle,
-            template_key=data_template,
-            valid_set_size=1,
-            train_set_pad=args.train_set_pad,
-            valid_set_pad=args.valid_set_pad,
-            seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
-            dtype=weight_dtype
-        )
-        datamodule.setup()
-
-        optimizer = optimizer_class(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            lr=args.ti_learning_rate,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=0.0,
-            eps=args.adam_epsilon,
-        )
-
-        lr_scheduler = get_scheduler(
-            "one_cycle",
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(datamodule.train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            train_epochs=args.ti_num_train_epochs,
-        )
-
-        trainer(
-            callbacks_fn=textual_inversion_strategy,
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            seed=args.seed,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=args.ti_num_train_epochs,
-            sample_frequency=args.ti_num_train_epochs // 5,
-            checkpoint_frequency=9999999,
-            # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
-            placeholder_tokens=[placeholder_token],
-            placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.ti_learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=True,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
-        )
-
-        embeddings.persist()
-
-    # Dreambooth
-
-    print("Phase 2: Dreambooth")
-
-    cur_dir = output_dir.joinpath("2-db")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    args.seed = seed_generator.seed()
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -746,12 +614,13 @@
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=cur_dir,
+        output_dir=output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
--
cgit v1.2.3-54-g00ecf