From 3c6ccadd3c12c54a1fa2280bce505a2dd511958a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 16 Jan 2023 07:27:45 +0100
Subject: Implemented extended Dreambooth training

---
 train_ti.py | 62 ++++++++++++++++++++++---------------------------------------
 1 file changed, 22 insertions(+), 40 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 2497519..48a2333 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,7 @@ from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
-from data.csv import VlpnDataModule, VlpnDataItem
+from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -446,15 +446,15 @@ def parse_args():
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]
-
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
     if len(args.initializer_tokens) == 0:
         raise ValueError("You must specify --initializer_tokens")
 
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
@@ -544,9 +544,6 @@ def main():
             args.train_batch_size * accelerator.num_processes
         )
 
-    if args.find_lr:
-        args.learning_rate = 1e-5
-
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -563,19 +560,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: VlpnDataItem):
-        cond1 = any(
-            keyword in part
-            for keyword in args.placeholder_tokens
-            for part in item.prompt
-        )
-        cond3 = args.collection is None or args.collection in item.collection
-        cond4 = args.exclude_collections is None or not any(
-            collection in item.collection
-            for collection in args.exclude_collections
-        )
-        return cond1 and cond3 and cond4
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -593,7 +577,7 @@ def main():
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=keyword_filter,
+        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
@@ -622,8 +606,6 @@ def main():
         text_encoder=text_encoder,
         vae=vae,
         noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
         dtype=weight_dtype,
         seed=args.seed,
         callbacks_fn=textual_inversion_strategy
@@ -638,25 +620,25 @@ def main():
         amsgrad=args.adam_amsgrad,
     )
 
-    if args.find_lr:
-        lr_scheduler = None
-    else:
-        lr_scheduler = get_scheduler(
-            args.lr_scheduler,
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            min_lr=args.lr_min_lr,
-            warmup_func=args.lr_warmup_func,
-            annealing_func=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
-            train_epochs=args.num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
-        )
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
 
     trainer(
+        project="textual_inversion",
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
--
cgit v1.2.3-54-g00ecf
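Note: this patch replaces the per-script keyword_filter closure with a shared helper imported from data.csv and binds its leading arguments with functools.partial at the call site. Below is a minimal sketch of what that shared helper presumably looks like, reconstructed from the closure removed above; the exact signature in data/csv.py is an assumption inferred from the new partial(...) call and is not part of this patch.

from typing import Optional

# Sketch only: reconstructed from the inline closure removed in this patch.
# The argument order (placeholder_tokens, collection, exclude_collections, item)
# is assumed from the call site
# partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections).
def keyword_filter(
    placeholder_tokens: list[str],
    collection: Optional[str],
    exclude_collections: Optional[list[str]],
    item,  # VlpnDataItem
) -> bool:
    # Keep items whose prompt mentions at least one placeholder token.
    cond1 = any(
        keyword in part
        for keyword in placeholder_tokens
        for part in item.prompt
    )
    # Keep items from the requested collection, if one was given.
    cond3 = collection is None or collection in item.collection
    # Drop items that belong to any excluded collection.
    cond4 = exclude_collections is None or not any(
        c in item.collection
        for c in exclude_collections
    )
    return cond1 and cond3 and cond4

# Bound at the call site exactly as in the patch:
# filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections)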