From 179a45253a5b3712f32bd127f693a6bb810a9c17 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 28 Mar 2023 16:24:22 +0200
Subject: Support num_train_steps arg again

---
 data/csv.py         |  8 ++++++--
 train_dreambooth.py | 17 +++++++++++------
 train_lora.py       | 17 +++++++++++------
 train_ti.py         | 30 +++++++++++++++++++++---------
 4 files changed, 49 insertions(+), 23 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 9770bec..c00ea07 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -143,7 +143,7 @@ class VlpnDataItem(NamedTuple):
 
 def keyword_filter(
     placeholder_tokens: Optional[list[str]],
-    collection: Optional[list[str]],
+    collections: Optional[list[str]],
     exclude_collections: Optional[list[str]],
     item: VlpnDataItem
 ):
@@ -152,11 +152,15 @@ def keyword_filter(
         for keyword in placeholder_tokens
         for part in item.prompt
     )
-    cond2 = collection is None or collection in item.collection
+    cond2 = collections is None or any(
+        collection in item.collection
+        for collection in collections
+    )
     cond3 = exclude_collections is None or not any(
         collection in item.collection
         for collection in exclude_collections
     )
+
     return cond1 and cond2 and cond3
 
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9345797..acb8287 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -189,13 +190,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -595,6 +595,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
         params_to_optimize += (
@@ -619,7 +624,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -631,7 +636,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
diff --git a/train_lora.py b/train_lora.py
index 7ecddf0..a9c6e52 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -178,13 +179,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -627,6 +627,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
@@ -647,7 +652,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -659,7 +664,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
diff --git a/train_ti.py b/train_ti.py
index e4fd464..7bcc72f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -207,7 +208,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
+    )
+    parser.add_argument(
+        "--num_train_steps",
+        type=int,
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -513,13 +519,13 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
-    args.alias_tokens += [
-        item
-        for pair in zip(args.placeholder_tokens, args.initializer_tokens)
-        for item in pair
-    ]
-
     if args.sequential:
+        args.alias_tokens += [
+            item
+            for pair in zip(args.placeholder_tokens, args.initializer_tokens)
+            for item in pair
+        ]
+
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -607,6 +613,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
@@ -682,7 +689,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
-        num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
@@ -752,6 +758,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
@@ -769,7 +780,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e3,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
    )
 
@@ -779,6 +790,7 @@ def main():
         val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        num_train_epochs=num_train_epochs,
         # --
         sample_output_dir=sample_output_dir,
         placeholder_tokens=placeholder_tokens,
--
cgit v1.2.3-70-g09d2
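
The data/csv.py hunk generalizes the collection filter: instead of a single membership test, an item now passes if it belongs to any of the requested collections, and still fails if it belongs to any excluded collection. Below is a standalone sketch of just that predicate; the helper name matches_collections is hypothetical and the surrounding VlpnDataItem plumbing is omitted, but the two conditions mirror cond2 and cond3 from the patch.

    from typing import Optional


    def matches_collections(
        collections: Optional[list[str]],
        exclude_collections: Optional[list[str]],
        item_collection: list[str],
    ) -> bool:
        # cond2 in the patch: no filter given, or the item belongs to at least
        # one of the requested collections.
        cond2 = collections is None or any(
            collection in item_collection for collection in collections
        )
        # cond3 in the patch: the item belongs to none of the excluded collections.
        cond3 = exclude_collections is None or not any(
            collection in item_collection for collection in exclude_collections
        )
        return cond2 and cond3


    item_collection = ["portraits"]
    print(matches_collections(["portraits", "landscapes"], None, item_collection))  # True
    print(matches_collections(["landscapes"], None, item_collection))               # False
    print(matches_collections(None, ["portraits"], item_collection))                # False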
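
The headline change is that all three trainers accept --num_train_steps again and derive an epoch count from it when --num_train_epochs is left unset. The sketch below illustrates that conversion outside the patched files; the helper name is hypothetical, steps_per_epoch stands in for len(datamodule.train_dataset), batch size and gradient accumulation are ignored, and the step budget is divided by the per-epoch step count, which is the conventional direction for turning a step budget into epochs.

    import math


    def epochs_from_step_budget(num_train_epochs, num_train_steps, steps_per_epoch):
        # Hypothetical helper: prefer an explicit epoch count, otherwise derive
        # one from the total step budget. steps_per_epoch stands in for
        # len(datamodule.train_dataset); batch size and gradient accumulation
        # are ignored in this sketch.
        if num_train_epochs is not None:
            return num_train_epochs
        # e.g. a 2000-step budget over 100 steps per epoch -> ceil(2000 / 100) = 20
        return math.ceil(num_train_steps / steps_per_epoch)


    # With the patch's default (--num_train_steps 2000) and 100 steps per epoch:
    print(epochs_from_step_budget(None, 2000, 100))  # 20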