From 179a45253a5b3712f32bd127f693a6bb810a9c17 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 28 Mar 2023 16:24:22 +0200
Subject: Support num_train_steps arg again

---
 train_ti.py | 30 +++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index e4fd464..7bcc72f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -207,7 +208,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
+    )
+    parser.add_argument(
+        "--num_train_steps",
+        type=int,
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -513,13 +519,13 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
-    args.alias_tokens += [
-        item
-        for pair in zip(args.placeholder_tokens, args.initializer_tokens)
-        for item in pair
-    ]
-
     if args.sequential:
+        args.alias_tokens += [
+            item
+            for pair in zip(args.placeholder_tokens, args.initializer_tokens)
+            for item in pair
+        ]
+
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -607,6 +613,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
@@ -682,7 +689,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
-        num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
@@ -752,6 +758,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+
     optimizer = create_optimizer(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
         lr=args.learning_rate,
@@ -769,7 +780,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e3,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -779,6 +790,7 @@ def main():
         val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        num_train_epochs=num_train_epochs,
         # --
         sample_output_dir=sample_output_dir,
         placeholder_tokens=placeholder_tokens,
--
cgit v1.2.3-54-g00ecf
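
The heart of the patch is the new fallback in main(): when --num_train_epochs is omitted, an epoch count is derived from the --num_train_steps budget. A minimal standalone sketch of that conversion (the helper name is hypothetical; it assumes one optimizer step per item in datamodule.train_dataset, as the patch does):

import math

def epochs_for_step_budget(num_train_steps: int, steps_per_epoch: int) -> int:
    # Round up so the run covers at least the requested number of steps.
    return math.ceil(num_train_steps / steps_per_epoch)

# Example: the default budget of 2000 steps over a 150-item training set
# yields ceil(2000 / 150) = 14 epochs.
print(epochs_for_step_budget(2000, 150))  # 14

The derived value then feeds both the LR scheduler (train_epochs=) and the trainer (num_train_epochs=), so the schedule length stays consistent with the actual run length.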