From 6c8cffe28baeafac77d047ff3f8ded9418033e2f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 16 Jan 2023 15:52:43 +0100
Subject: More training adjustments

---
 train_dreambooth.py | 71 ++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 59 insertions(+), 12 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index a9fbbbd..1dc41b1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -54,6 +54,18 @@ def parse_args():
         type=str,
         default="template",
     )
+    parser.add_argument(
+        "--train_set_pad",
+        type=int,
+        default=None,
+        help="The number to fill train dataset items up to."
+    )
+    parser.add_argument(
+        "--valid_set_pad",
+        type=int,
+        default=None,
+        help="The number to fill validation dataset items up to."
+    )
     parser.add_argument(
         "--project",
         type=str,
@@ -187,11 +199,23 @@ def parse_args():
         type=int,
         default=100
     )
+    parser.add_argument(
+        "--ti_data_template",
+        type=str,
+        nargs='*',
+        default=[],
+    )
     parser.add_argument(
         "--ti_num_train_epochs",
         type=int,
         default=10
     )
+    parser.add_argument(
+        "--ti_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -458,6 +482,12 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
+    if isinstance(args.ti_data_template, str):
+        args.ti_data_template = [args.ti_data_template]
+
+    if len(args.ti_data_template) == 0:
+        raise ValueError("You must specify --ti_data_template")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -491,6 +521,8 @@ def main():
 
     set_seed(args.seed)
 
+    seed_generator = torch.Generator().manual_seed(args.seed)
+
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -512,6 +544,8 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
+        embeddings.persist()
+
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
@@ -545,7 +579,6 @@ def main():
         vae=vae,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
-        seed=args.seed,
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
     )
@@ -557,13 +590,17 @@ def main():
     cur_dir = output_dir.joinpath("1-ti")
     cur_dir.mkdir(parents=True, exist_ok=True)
 
-    for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors):
-        print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})")
-
+    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
+        range(len(args.placeholder_tokens)),
+        args.placeholder_tokens,
+        args.initializer_tokens,
+        args.num_vectors,
+        args.ti_data_template
+    ):
         cur_subdir = cur_dir.joinpath(placeholder_token)
         cur_subdir.mkdir(parents=True, exist_ok=True)
 
-        placeholder_token_ids, _ = add_placeholder_tokens(
+        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
             placeholder_tokens=[placeholder_token],
@@ -571,17 +608,23 @@ def main():
             num_vectors=[num_vectors]
         )
 
+        print(
+            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
+
+        args.seed = seed_generator.seed()
+
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
-            batch_size=args.train_batch_size,
+            batch_size=args.ti_batch_size,
             tokenizer=tokenizer,
             class_subdir=args.class_image_dir,
             num_class_images=args.num_class_images,
             size=args.resolution,
             shuffle=not args.no_tag_shuffle,
-            template_key=args.train_data_template,
+            template_key=data_template,
             valid_set_size=1,
-            valid_set_repeat=args.valid_set_repeat,
+            train_set_pad=args.train_set_pad,
+            valid_set_pad=args.valid_set_pad,
             seed=args.seed,
             filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
             dtype=weight_dtype
@@ -591,7 +634,9 @@ def main():
         optimizer = optimizer_class(
             text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
             lr=args.ti_learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
             weight_decay=0.0,
+            eps=args.adam_epsilon,
         )
 
         lr_scheduler = get_scheduler(
@@ -600,7 +645,6 @@ def main():
             num_training_steps_per_epoch=len(datamodule.train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             train_epochs=args.ti_num_train_epochs,
-            warmup_epochs=args.ti_num_train_epochs // 4,
        )
 
         trainer(
@@ -608,10 +652,11 @@ def main():
             project="textual_inversion",
             train_dataloader=datamodule.train_dataloader,
             val_dataloader=datamodule.val_dataloader,
+            seed=args.seed,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
             num_train_epochs=args.ti_num_train_epochs,
-            sample_frequency=2,
+            sample_frequency=args.ti_num_train_epochs // 5,
             checkpoint_frequency=9999999,
             # --
             tokenizer=tokenizer,
@@ -637,7 +682,7 @@ def main():
     cur_dir = output_dir.joinpath("2-db")
     cur_dir.mkdir(parents=True, exist_ok=True)
 
-    args.seed = (args.seed + 28635) >> 32
+    args.seed = seed_generator.seed()
 
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
@@ -654,7 +699,8 @@ def main():
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
-        valid_set_repeat=args.valid_set_repeat,
+        train_set_pad=args.train_set_pad,
+        valid_set_pad=args.valid_set_pad,
         seed=args.seed,
         filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
         dtype=weight_dtype
@@ -697,6 +743,7 @@ def main():
         project="dreambooth",
         train_dataloader=datamodule.train_dataloader,
         val_dataloader=datamodule.val_dataloader,
+        seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
--
cgit v1.2.3-54-g00ecf
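
The per-phase seeds now come from a single torch.Generator rather than the old
`(args.seed + 28635) >> 32` expression, which shifts right by 32 bits and so
evaluates to 0 for any seed below 2**32 - 28635. One detail worth knowing:
`Generator.seed()` re-seeds the generator from fresh OS entropy and returns the
new seed, so each phase gets a different value, but the sequence is not
reproducible from --seed alone (manual_seed only governs draws made *through*
the generator). A minimal sketch of both behaviors, using only the public
torch API, with illustrative values:

    import torch

    # One generator seeded from the run's base seed, as in the patch.
    seed_generator = torch.Generator().manual_seed(42)

    # Generator.seed() pulls a fresh non-deterministic seed from the OS,
    # stores it, and returns it -- each call yields a different seed,
    # independent of manual_seed(42) above.
    phase1_seed = seed_generator.seed()
    phase2_seed = seed_generator.seed()

    # A reproducible alternative (hypothetical, not what the patch does):
    # derive per-phase seeds by drawing through the manually seeded generator.
    det_gen = torch.Generator().manual_seed(42)
    phase_seeds = torch.randint(0, 2**32 - 1, (2,), generator=det_gen).tolist()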
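--ti_data_template supplies one prompt-template key per placeholder token
(nargs='*'), consumed positionally by the phase-1 loop alongside
--placeholder_tokens, --initializer_tokens, and --num_vectors. Because zip
stops at its shortest argument, passing fewer templates than tokens silently
skips the trailing tokens; parse_args only rejects the empty case. The
zip(range(len(...)), ...) construction is equivalent to enumerate(zip(...));
a runnable sketch with stand-in values (the real ones come from argparse):

    # Stand-in argparse fields; names follow the patch, values are made up.
    class Args:
        placeholder_tokens = ["<tok-a>", "<tok-b>"]
        initializer_tokens = ["woman", "man"]
        num_vectors = [4, 4]
        ti_data_template = ["style", "subject"]

    args = Args()

    # enumerate(zip(...)) yields the phase index i directly, matching the
    # patch's zip(range(len(args.placeholder_tokens)), ...) pattern.
    for i, (placeholder_token, initializer_token, num_vectors, data_template) in enumerate(
        zip(args.placeholder_tokens, args.initializer_tokens,
            args.num_vectors, args.ti_data_template)
    ):
        print(f"Phase 1.{i + 1}: {placeholder_token} ({num_vectors}) "
              f"({initializer_token}) -> template '{data_template}'")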
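--train_set_pad and --valid_set_pad replace the old valid_set_repeat knob:
instead of repeating the validation set a fixed number of times, each split is
filled up to an absolute item count. The implementation lives in
VlpnDataModule, which this patch does not touch, so the sketch below is only
one plausible padding strategy -- cyclic repetition up to the target -- and
pad_items is a hypothetical helper, not the module's actual code. Padding this
way keeps very small datasets divisible into full, fixed-size batches:

    from itertools import cycle, islice
    from typing import Optional, Sequence, TypeVar

    T = TypeVar("T")

    def pad_items(items: Sequence[T], target: Optional[int]) -> list[T]:
        """Repeat items cyclically until the list holds `target` entries.

        target=None (the CLI default of --train_set_pad / --valid_set_pad)
        leaves the items unchanged.
        """
        if target is None or len(items) >= target:
            return list(items)
        return list(islice(cycle(items), target))

    print(pad_items(["a.png", "b.png", "c.png"], 8))
    # ['a.png', 'b.png', 'c.png', 'a.png', 'b.png', 'c.png', 'a.png', 'b.png']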