From eb6a92abda5893c975437026cdaf0ce0bfefe2a4 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 10 Apr 2023 10:34:12 +0200 Subject: Update --- train_lora.py | 66 ++++++++++++++++++++++++++------------- train_ti.py | 68 +++++++++++++++++++++++++---------------- training/functional.py | 4 +-- training/strategy/dreambooth.py | 2 +- training/strategy/lora.py | 2 +- training/strategy/ti.py | 2 +- 6 files changed, 89 insertions(+), 55 deletions(-) diff --git a/train_lora.py b/train_lora.py index e81742a..4bbc64e 100644 --- a/train_lora.py +++ b/train_lora.py @@ -198,6 +198,11 @@ def parse_args(): default=None, help="The embeddings directory where Textual Inversion embeddings are stored.", ) + parser.add_argument( + "--train_dir_embeddings", + action="store_true", + help="Train embeddings loaded from embeddings directory.", + ) parser.add_argument( "--collection", type=str, @@ -439,6 +444,12 @@ def parse_args(): default=1, help="How often to save a checkpoint and sample image", ) + parser.add_argument( + "--sample_num", + type=int, + default=None, + help="How often to save a checkpoint and sample image (in number of samples)", + ) parser.add_argument( "--sample_image_size", type=int, @@ -681,27 +692,36 @@ def main(): embeddings.persist() print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") + placeholder_token_ids = [] + if args.embeddings_dir is not None: embeddings_dir = Path(args.embeddings_dir) if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - embeddings.persist() print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") - placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( - tokenizer=tokenizer, - embeddings=embeddings, - placeholder_tokens=args.placeholder_tokens, - initializer_tokens=args.initializer_tokens, - num_vectors=args.num_vectors, - initializer_noise=args.initializer_noise, - ) - stats = list(zip( - args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids - )) - print(f"Training embeddings: {stats}") + if args.train_dir_embeddings: + args.placeholder_tokens = added_tokens + placeholder_token_ids = added_ids + print("Training embeddings from embeddings dir") + else: + embeddings.persist() + + if not args.train_dir_embeddings: + placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=args.placeholder_tokens, + initializer_tokens=args.initializer_tokens, + num_vectors=args.num_vectors, + initializer_noise=args.initializer_noise, + ) + stats = list(zip( + args.placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids + )) + print(f"Training embeddings: {stats}") if args.scale_lr: args.learning_rate_unet = ( @@ -897,6 +917,8 @@ def main(): args.num_train_steps / len(lora_datamodule.train_dataset) ) * args.gradient_accumulation_steps lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) + if args.sample_num is not None: + lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) params_to_optimize = [] group_labels = [] @@ -930,15 +952,6 @@ def main(): ] group_labels += ["unet", "text"] - lora_optimizer = create_optimizer(params_to_optimize) - - lora_lr_scheduler = create_lr_scheduler( - gradient_accumulation_steps=args.gradient_accumulation_steps, - optimizer=lora_optimizer, - num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), - train_epochs=num_train_epochs, - ) - training_iter = 0 while True: @@ -952,6 +965,15 @@ def main(): print(f"============ LoRA cycle {training_iter} ============") print("") + lora_optimizer = create_optimizer(params_to_optimize) + + lora_lr_scheduler = create_lr_scheduler( + gradient_accumulation_steps=args.gradient_accumulation_steps, + optimizer=lora_optimizer, + num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), + train_epochs=num_train_epochs, + ) + lora_project = f"lora_{training_iter}" lora_checkpoint_output_dir = output_dir / lora_project / "model" lora_sample_output_dir = output_dir / lora_project / "samples" diff --git a/train_ti.py b/train_ti.py index ebac302..eb08bda 100644 --- a/train_ti.py +++ b/train_ti.py @@ -151,6 +151,11 @@ def parse_args(): default=None, help="The embeddings directory where Textual Inversion embeddings are stored.", ) + parser.add_argument( + "--train_dir_embeddings", + action="store_true", + help="Train embeddings loaded from embeddings directory.", + ) parser.add_argument( "--collection", type=str, @@ -403,6 +408,12 @@ def parse_args(): action='store_true', help="If checkpoints are saved on maximum accuracy", ) + parser.add_argument( + "--sample_num", + type=int, + default=None, + help="How often to save a checkpoint and sample image (in number of samples)", + ) parser.add_argument( "--sample_frequency", type=int, @@ -669,9 +680,14 @@ def main(): raise ValueError("--embeddings_dir must point to an existing directory") added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - embeddings.persist() print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + if args.train_dir_embeddings: + args.placeholder_tokens = added_tokens + print("Training embeddings from embeddings dir") + else: + embeddings.persist() + if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -852,28 +868,8 @@ def main(): args.num_train_steps / len(datamodule.train_dataset) ) * args.gradient_accumulation_steps sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) - - optimizer = create_optimizer( - text_encoder.text_model.embeddings.token_override_embedding.parameters(), - lr=args.learning_rate, - ) - - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_training_steps_per_epoch=len(datamodule.train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps, - min_lr=args.lr_min_lr, - warmup_func=args.lr_warmup_func, - annealing_func=args.lr_annealing_func, - warmup_exp=args.lr_warmup_exp, - annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, - end_lr=1e3, - train_epochs=num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, - mid_point=args.lr_mid_point, - ) + if args.sample_num is not None: + sample_frequency = math.ceil(num_train_epochs / args.sample_num) training_iter = 0 @@ -888,6 +884,28 @@ def main(): print(f"------------ TI cycle {training_iter} ------------") print("") + optimizer = create_optimizer( + text_encoder.text_model.embeddings.token_override_embedding.parameters(), + lr=args.learning_rate, + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps_per_epoch=len(datamodule.train_dataloader), + gradient_accumulation_steps=args.gradient_accumulation_steps, + min_lr=args.lr_min_lr, + warmup_func=args.lr_warmup_func, + annealing_func=args.lr_annealing_func, + warmup_exp=args.lr_warmup_exp, + annealing_exp=args.lr_annealing_exp, + cycles=args.lr_cycles, + end_lr=1e3, + train_epochs=num_train_epochs, + warmup_epochs=args.lr_warmup_epochs, + mid_point=args.lr_mid_point, + ) + project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}" sample_output_dir = output_dir / project / "samples" checkpoint_output_dir = output_dir / project / "checkpoints" @@ -908,10 +926,6 @@ def main(): placeholder_token_ids=placeholder_token_ids, ) - response = input("Run another cycle? [y/n] ") - continue_training = response.lower().strip() != "n" - training_iter += 1 - if not args.sequential: run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) else: diff --git a/training/functional.py b/training/functional.py index e14aeea..46d25f6 100644 --- a/training/functional.py +++ b/training/functional.py @@ -644,11 +644,9 @@ def train( min_snr_gamma: int = 5, **kwargs, ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) - kwargs.update(extra) - vae.to(accelerator.device, dtype=dtype) vae.requires_grad_(False) vae.eval() diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 695174a..42624cd 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -198,7 +198,7 @@ def dreambooth_prepare( text_encoder.text_model.embeddings.requires_grad_(False) - return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler dreambooth_strategy = TrainingStrategy( diff --git a/training/strategy/lora.py b/training/strategy/lora.py index ae85401..73ec8f2 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -184,7 +184,7 @@ def lora_prepare( text_encoder.text_model.embeddings.token_override_embedding.params.requires_grad_(True) - return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler lora_strategy = TrainingStrategy( diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 9cdc1bb..363c3f9 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -207,7 +207,7 @@ def textual_inversion_prepare( text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} + return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler textual_inversion_strategy = TrainingStrategy( -- cgit v1.2.3-70-g09d2