From 36440e48ce279872d6e736bcb1bf57d13da73a11 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 16 Jan 2023 17:09:01 +0100
Subject: Moved multi-TI code from Dreambooth to TI script

---
 data/csv.py            |   2 +-
 train_dreambooth.py    | 135 +-----------------------------
 train_ti.py            | 221 +++++++++++++++++++++++++------------------------
 training/functional.py |  17 +++-
 4 files changed, 131 insertions(+), 244 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 85b98f8..6857b6f 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -282,7 +282,7 @@ class VlpnDataModule():
         collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)

         if valid_set_size == 0:
-            data_train, data_val = items, items[:1]
+            data_train, data_val = items, []
         else:
             data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
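Note (illustration, not part of the patch): with valid_set_size == 0 the datamodule previously validated against items[:1], a one-item overlap with the training set; it now produces an empty validation set. A minimal standalone sketch of the resulting split logic, with toy values standing in for the module's real items:

    import torch
    from torch.utils.data import random_split

    items = list(range(10))  # stand-in for the parsed dataset items
    valid_set_size = 0       # hypothetical configuration

    if valid_set_size == 0:
        # New behavior: train on everything, validate on nothing
        # (previously: data_val = items[:1], overlapping the training set).
        data_train, data_val = items, []
    else:
        train_set_size = len(items) - valid_set_size
        generator = torch.Generator().manual_seed(0)
        data_train, data_val = random_split(
            items, [train_set_size, valid_set_size], generator=generator
        )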
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1dc41b1..6511f9b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -199,23 +199,6 @@ def parse_args():
         type=int,
         default=100
     )
-    parser.add_argument(
-        "--ti_data_template",
-        type=str,
-        nargs='*',
-        default=[],
-    )
-    parser.add_argument(
-        "--ti_num_train_epochs",
-        type=int,
-        default=10
-    )
-    parser.add_argument(
-        "--ti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -244,12 +227,6 @@ def parse_args():
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--ti_learning_rate",
-        type=float,
-        default=1e-2,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -482,12 +459,6 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

-    if isinstance(args.ti_data_template, str):
-        args.ti_data_template = [args.ti_data_template]
-
-    if len(args.ti_data_template) == 0:
-        raise ValueError("You must specify --ti_data_template")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]

@@ -521,8 +492,6 @@ def main():

     set_seed(args.seed)

-    seed_generator = torch.Generator().manual_seed(args.seed)
-
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -583,107 +552,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )

-    # Initial TI
-
-    print("Phase 1: Textual Inversion")
-
-    cur_dir = output_dir.joinpath("1-ti")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.ti_data_template
-    ):
-        cur_subdir = cur_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
-
-        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-            tokenizer=tokenizer,
-            embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
-        )
-
-        print(
-            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
-
-        args.seed = seed_generator.seed()
-
-        datamodule = VlpnDataModule(
-            data_file=args.train_data_file,
-            batch_size=args.ti_batch_size,
-            tokenizer=tokenizer,
-            class_subdir=args.class_image_dir,
-            num_class_images=args.num_class_images,
-            size=args.resolution,
-            shuffle=not args.no_tag_shuffle,
-            template_key=data_template,
-            valid_set_size=1,
-            train_set_pad=args.train_set_pad,
-            valid_set_pad=args.valid_set_pad,
-            seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
-            dtype=weight_dtype
-        )
-        datamodule.setup()
-
-        optimizer = optimizer_class(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            lr=args.ti_learning_rate,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=0.0,
-            eps=args.adam_epsilon,
-        )
-
-        lr_scheduler = get_scheduler(
-            "one_cycle",
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(datamodule.train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            train_epochs=args.ti_num_train_epochs,
-        )
-
-        trainer(
-            callbacks_fn=textual_inversion_strategy,
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            seed=args.seed,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=args.ti_num_train_epochs,
-            sample_frequency=args.ti_num_train_epochs // 5,
-            checkpoint_frequency=9999999,
-            # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
-            placeholder_tokens=[placeholder_token],
-            placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.ti_learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=True,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
-        )
-
-        embeddings.persist()
-
-    # Dreambooth
-
-    print("Phase 2: Dreambooth")
-
-    cur_dir = output_dir.joinpath("2-db")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    args.seed = seed_generator.seed()
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -746,12 +614,13 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=cur_dir,
+        output_dir=output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
diff --git a/train_ti.py b/train_ti.py
index 7aecdef..adba8d4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -51,6 +51,7 @@ def parse_args():
     parser.add_argument(
         "--train_data_template",
         type=str,
+        nargs='*',
         default="template",
     )
     parser.add_argument(
@@ -468,11 +469,17 @@ def parse_args():
         args.num_vectors = 1

     if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
+        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)

     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

+    if isinstance(args.train_data_template, str):
+        args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
+
+    if len(args.placeholder_tokens) != len(args.train_data_template):
+        raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -507,6 +514,8 @@ def main():

     set_seed(args.seed)

+    seed_generator = torch.Generator().manual_seed(args.seed)
+
     save_args(output_dir, args)

     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -531,19 +540,6 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

-    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-        tokenizer=tokenizer,
-        embeddings=embeddings,
-        placeholder_tokens=args.placeholder_tokens,
-        initializer_tokens=args.initializer_tokens,
-        num_vectors=args.num_vectors
-    )
-
-    if len(placeholder_token_ids) != 0:
-        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
-        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
-        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -566,43 +562,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    datamodule = VlpnDataModule(
-        data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=args.class_image_dir,
-        num_class_images=args.num_class_images,
-        size=args.resolution,
-        num_buckets=args.num_buckets,
-        progressive_buckets=args.progressive_buckets,
-        bucket_step_size=args.bucket_step_size,
-        bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
-        shuffle=not args.no_tag_shuffle,
-        template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
-        train_set_pad=args.train_set_pad,
-        valid_set_pad=args.valid_set_pad,
-        seed=args.seed,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-        dtype=weight_dtype
-    )
-    datamodule.setup()
-
-    if args.num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            sample_scheduler,
-            datamodule.train_dataset,
-            args.sample_batch_size,
-            args.sample_image_size,
-            args.sample_steps
-        )
-
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -615,63 +574,111 @@ def main():
         callbacks_fn=textual_inversion_strategy
     )

-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
+    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
+        range(len(args.placeholder_tokens)),
+        args.placeholder_tokens,
+        args.initializer_tokens,
+        args.num_vectors,
+        args.train_data_template
+    ):
+        cur_subdir = output_dir.joinpath(placeholder_token)
+        cur_subdir.mkdir(parents=True, exist_ok=True)
+
+        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=[placeholder_token],
+            initializer_tokens=[initializer_token],
+            num_vectors=[num_vectors]
+        )

-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        train_epochs=args.num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-    )
-
-    trainer(
-        project="textual_inversion",
-        train_dataloader=datamodule.train_dataloader,
-        val_dataloader=datamodule.val_dataloader,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
-        sample_frequency=args.sample_frequency,
-        checkpoint_frequency=args.checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
-        # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
-        output_dir=output_dir,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        learning_rate=args.learning_rate,
-        gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay_factor=args.emb_decay_factor,
-        emb_decay_start=args.emb_decay_start,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
-        sample_batch_size=args.sample_batch_size,
-        sample_num_batches=args.sample_batches,
-        sample_num_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
-    )
+        print(
+            f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
+
+        args.seed = seed_generator.seed()
+
+        datamodule = VlpnDataModule(
+            data_file=args.train_data_file,
+            batch_size=args.train_batch_size,
+            tokenizer=tokenizer,
+            class_subdir=args.class_image_dir,
+            num_class_images=args.num_class_images,
+            size=args.resolution,
+            num_buckets=args.num_buckets,
+            progressive_buckets=args.progressive_buckets,
+            bucket_step_size=args.bucket_step_size,
+            bucket_max_pixels=args.bucket_max_pixels,
+            dropout=args.tag_dropout,
+            shuffle=not args.no_tag_shuffle,
+            template_key=data_template,
+            valid_set_size=args.valid_set_size,
+            train_set_pad=args.train_set_pad,
+            valid_set_pad=args.valid_set_pad,
+            seed=args.seed,
+            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
+            dtype=weight_dtype
+        )
+        datamodule.setup()
+
+        optimizer = optimizer_class(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(datamodule.train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
+        )
+
+        trainer(
+            project="textual_inversion",
+            train_dataloader=datamodule.train_dataloader,
+            val_dataloader=datamodule.val_dataloader,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            num_train_epochs=args.num_train_epochs,
+            sample_frequency=args.sample_frequency,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=global_step_offset,
+            with_prior_preservation=args.num_class_images != 0,
+            prior_loss_weight=args.prior_loss_weight,
+            # --
+            tokenizer=tokenizer,
+            sample_scheduler=sample_scheduler,
+            output_dir=cur_subdir,
+            placeholder_tokens=[placeholder_token],
+            placeholder_token_ids=placeholder_token_ids,
+            learning_rate=args.learning_rate,
+            gradient_checkpointing=args.gradient_checkpointing,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay_factor=args.emb_decay_factor,
+            emb_decay_start=args.emb_decay_start,
+            use_ema=args.use_ema,
+            ema_inv_gamma=args.ema_inv_gamma,
+            ema_power=args.ema_power,
+            ema_max_decay=args.ema_max_decay,
+            sample_batch_size=args.sample_batch_size,
+            sample_num_batches=args.sample_batches,
+            sample_num_steps=args.sample_steps,
+            sample_image_size=args.sample_image_size,
+        )
+
+        embeddings.persist()


 if __name__ == "__main__":
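Note (illustration, not part of the patch): the new per-token loop drives five parallel lists through a single zip, with the range() iterable doubling as a counter. A toy model of the unpacking, with made-up values:

    placeholder_tokens = ["<tok1>", "<tok2>"]        # hypothetical values
    initializer_tokens = ["cat", "dog"]
    num_vectors = [4, 2]
    train_data_template = ["template_a", "template_b"]

    for i, placeholder_token, initializer_token, n_vectors, template in zip(
        range(len(placeholder_tokens)),
        placeholder_tokens,
        initializer_tokens,
        num_vectors,
        train_data_template,
    ):
        # Same effect as: for i, (tok, init, n, tpl) in enumerate(zip(...))
        print(f"{i + 1}: {placeholder_token} <- {initializer_token} ({n_vectors} vectors, {template})")

One subtlety of the parse_args() normalization above: the scalar broadcast only applies to the string default, since a single --train_data_template passed on the command line arrives as a one-element list (because of nargs='*') and must still match the number of placeholder tokens.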
diff --git a/training/functional.py b/training/functional.py
index b6b5d87..1548784 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -515,6 +515,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
+    prepare_unet: bool = False,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -523,9 +524,19 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
+    prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
+
+    if prepare_unet:
+        prep.append(unet)
+
+    prep = accelerator.prepare(*prep)
+
+    if prepare_unet:
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
+    else:
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
+
+    unet.to(accelerator.device, dtype=dtype)

     vae.to(accelerator.device, dtype=dtype)
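Note (illustration, not part of the patch): the prepare_unet flag is what lets both scripts share train(). Dreambooth trains the UNet and passes prepare_unet=True, so the UNet goes through Accelerate's wrapping; the TI script leaves it False, and the frozen UNet is only moved to the right device and dtype. A trimmed, self-contained sketch of the same pattern, omitting the dataloaders and LR scheduler that the real function also threads through prepare:

    from accelerate import Accelerator

    def prepare_all(accelerator: Accelerator, unet, text_encoder, optimizer,
                    prepare_unet: bool = False, dtype=None):
        prep = [text_encoder, optimizer]
        if prepare_unet:
            # Appended last, so the unpack order below stays stable
            # whether or not the UNet participates.
            prep.append(unet)

        prep = accelerator.prepare(*prep)

        if prepare_unet:
            text_encoder, optimizer, unet = prep
        else:
            text_encoder, optimizer = prep

        # As in the patch, the UNet is moved/cast manually either way; the
        # frozen case relies on this instead of accelerator.prepare().
        unet.to(accelerator.device, dtype=dtype)
        return unet, text_encoder, optimizer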
-- 
cgit v1.2.3-54-g00ecf