From 83808fe00ac891ad2f625388d144c318b2cb5bfe Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 21:53:07 +0100 Subject: WIP: Modularization ("free(): invalid pointer" my ass) --- train.py | 672 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 672 insertions(+) create mode 100644 train.py (limited to 'train.py') diff --git a/train.py b/train.py new file mode 100644 index 0000000..d8644c4 --- /dev/null +++ b/train.py @@ -0,0 +1,672 @@ +import argparse +import datetime +import logging +from pathlib import Path + +import torch +import torch.utils.checkpoint + +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import LoggerType, set_seed +from slugify import slugify + +from data.csv import VlpnDataModule, VlpnDataItem +from util import load_config, load_embeddings_from_dir + +from trainer.ti import TextualInversionTrainingStrategy +from trainer.base import Trainer +from training.optimization import get_scheduler +from training.util import save_args, generate_class_images, add_placeholder_tokens, get_models + +logger = get_logger(__name__) + + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.benchmark = True + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Simple example of a training script." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_file", + type=str, + default=None, + help="A CSV file containing the training data." + ) + parser.add_argument( + "--train_data_template", + type=str, + default="template", + ) + parser.add_argument( + "--project", + type=str, + default=None, + help="The name of the current project.", + ) + parser.add_argument( + "--placeholder_tokens", + type=str, + nargs='*', + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--initializer_tokens", + type=str, + nargs='*', + help="A token to use as initializer word." + ) + parser.add_argument( + "--num_vectors", + type=int, + nargs='*', + help="Number of vectors per embedding." + ) + parser.add_argument( + "--num_class_images", + type=int, + default=1, + help="How many class images to generate." + ) + parser.add_argument( + "--class_image_dir", + type=str, + default="cls", + help="The directory where class images will be saved.", + ) + parser.add_argument( + "--exclude_collections", + type=str, + nargs='*', + help="Exclude all items with a listed collection.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output/text-inversion", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--embeddings_dir", + type=str, + default=None, + help="The embeddings directory where Textual Inversion embeddings are stored.", + ) + parser.add_argument( + "--collection", + type=str, + nargs='*', + help="A collection to filter the dataset.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="A seed for reproducible training." + ) + parser.add_argument( + "--resolution", + type=int, + default=768, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--num_buckets", + type=int, + default=0, + help="Number of aspect ratio buckets in either direction.", + ) + parser.add_argument( + "--progressive_buckets", + action="store_true", + help="Include images in smaller buckets as well.", + ) + parser.add_argument( + "--bucket_step_size", + type=int, + default=64, + help="Step size between buckets.", + ) + parser.add_argument( + "--bucket_max_pixels", + type=int, + default=None, + help="Maximum pixels per bucket.", + ) + parser.add_argument( + "--tag_dropout", + type=float, + default=0, + help="Tag dropout probability.", + ) + parser.add_argument( + "--no_tag_shuffle", + action="store_true", + help="Shuffle tags.", + ) + parser.add_argument( + "--vector_dropout", + type=int, + default=0, + help="Vector dropout probability.", + ) + parser.add_argument( + "--vector_shuffle", + type=str, + default="auto", + help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=100 + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--find_lr", + action="store_true", + help="Automatically find a learning rate (no training).", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="one_cycle", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup", "one_cycle"]' + ), + ) + parser.add_argument( + "--lr_warmup_epochs", + type=int, + default=10, + help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_cycles", + type=int, + default=None, + help="Number of restart cycles in the lr scheduler." + ) + parser.add_argument( + "--lr_warmup_func", + type=str, + default="cos", + help='Choose between ["linear", "cos"]' + ) + parser.add_argument( + "--lr_warmup_exp", + type=int, + default=1, + help='If lr_warmup_func is "cos", exponent to modify the function' + ) + parser.add_argument( + "--lr_annealing_func", + type=str, + default="cos", + help='Choose between ["linear", "half_cos", "cos"]' + ) + parser.add_argument( + "--lr_annealing_exp", + type=int, + default=1, + help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' + ) + parser.add_argument( + "--lr_min_lr", + type=float, + default=0.04, + help="Minimum learning rate in the lr scheduler." + ) + parser.add_argument( + "--use_ema", + action="store_true", + help="Whether to use EMA model." + ) + parser.add_argument( + "--ema_inv_gamma", + type=float, + default=1.0 + ) + parser.add_argument( + "--ema_power", + type=float, + default=1 + ) + parser.add_argument( + "--ema_max_decay", + type=float, + default=0.9999 + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer." + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer." + ) + parser.add_argument( + "--adam_weight_decay", + type=float, + default=0, + help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer" + ) + parser.add_argument( + "--adam_amsgrad", + type=bool, + default=False, + help="Amsgrad value for the Adam optimizer" + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument( + "--checkpoint_frequency", + type=int, + default=5, + help="How often to save a checkpoint and sample image (in epochs)", + ) + parser.add_argument( + "--sample_frequency", + type=int, + default=1, + help="How often to save a checkpoint and sample image (in epochs)", + ) + parser.add_argument( + "--sample_image_size", + type=int, + default=768, + help="Size of sample images", + ) + parser.add_argument( + "--sample_batches", + type=int, + default=1, + help="Number of sample batches to generate per checkpoint", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=1, + help="Number of samples to generate per batch", + ) + parser.add_argument( + "--valid_set_size", + type=int, + default=None, + help="Number of images in the validation dataset." + ) + parser.add_argument( + "--valid_set_repeat", + type=int, + default=1, + help="Times the images in the validation dataset are repeated." + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_steps", + type=int, + default=20, + help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", + ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss." + ) + parser.add_argument( + "--emb_decay_target", + default=0.4, + type=float, + help="Embedding decay target." + ) + parser.add_argument( + "--emb_decay_factor", + default=0, + type=float, + help="Embedding decay factor." + ) + parser.add_argument( + "--emb_decay_start", + default=1e-4, + type=float, + help="Embedding decay start offset." + ) + parser.add_argument( + "--noise_timesteps", + type=int, + default=1000, + ) + parser.add_argument( + "--resume_from", + type=str, + default=None, + help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" + ) + parser.add_argument( + "--global_step", + type=int, + default=0, + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to a JSON configuration file containing arguments for invoking this script." + ) + + args = parser.parse_args() + if args.config is not None: + args = load_config(args.config) + args = parser.parse_args(namespace=argparse.Namespace(**args)) + + if args.train_data_file is None: + raise ValueError("You must specify --train_data_file") + + if args.pretrained_model_name_or_path is None: + raise ValueError("You must specify --pretrained_model_name_or_path") + + if args.project is None: + raise ValueError("You must specify --project") + + if isinstance(args.placeholder_tokens, str): + args.placeholder_tokens = [args.placeholder_tokens] + + if len(args.placeholder_tokens) == 0: + args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)] + + if isinstance(args.initializer_tokens, str): + args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) + + if len(args.initializer_tokens) == 0: + raise ValueError("You must specify --initializer_tokens") + + if len(args.placeholder_tokens) != len(args.initializer_tokens): + raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") + + if args.num_vectors is None: + args.num_vectors = 1 + + if isinstance(args.num_vectors, int): + args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) + + if len(args.placeholder_tokens) != len(args.num_vectors): + raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") + + if isinstance(args.collection, str): + args.collection = [args.collection] + + if isinstance(args.exclude_collections, str): + args.exclude_collections = [args.exclude_collections] + + if args.output_dir is None: + raise ValueError("You must specify --output_dir") + + return args + + +def main(): + args = parse_args() + + global_step_offset = args.global_step + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) + output_dir.mkdir(parents=True, exist_ok=True) + + accelerator = Accelerator( + log_with=LoggerType.TENSORBOARD, + logging_dir=f"{output_dir}", + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + + logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + + if args.seed is None: + args.seed = torch.random.seed() >> 32 + + set_seed(args.seed) + + save_args(output_dir, args) + + tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( + args.pretrained_model_name_or_path) + + tokenizer.set_use_vector_shuffle(args.vector_shuffle) + tokenizer.set_dropout(args.vector_dropout) + + vae.enable_slicing() + vae.set_use_memory_efficient_attention_xformers(True) + unet.set_use_memory_efficient_attention_xformers(True) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + if args.embeddings_dir is not None: + embeddings_dir = Path(args.embeddings_dir) + if not embeddings_dir.exists() or not embeddings_dir.is_dir(): + raise ValueError("--embeddings_dir must point to an existing directory") + + added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) + print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + + placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=args.placeholder_tokens, + initializer_tokens=args.initializer_tokens, + num_vectors=args.num_vectors + ) + + if len(placeholder_token_ids) != 0: + initializer_token_id_lens = [len(id) for id in initializer_token_ids] + placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) + print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * + args.train_batch_size * accelerator.num_processes + ) + + if args.find_lr: + args.learning_rate = 1e-5 + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + amsgrad=args.adam_amsgrad, + ) + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + def keyword_filter(item: VlpnDataItem): + cond1 = any( + keyword in part + for keyword in args.placeholder_tokens + for part in item.prompt + ) + cond3 = args.collection is None or args.collection in item.collection + cond4 = args.exclude_collections is None or not any( + collection in item.collection + for collection in args.exclude_collections + ) + return cond1 and cond3 and cond4 + + datamodule = VlpnDataModule( + data_file=args.train_data_file, + batch_size=args.train_batch_size, + tokenizer=tokenizer, + class_subdir=args.class_image_dir, + num_class_images=args.num_class_images, + size=args.resolution, + num_buckets=args.num_buckets, + progressive_buckets=args.progressive_buckets, + bucket_step_size=args.bucket_step_size, + bucket_max_pixels=args.bucket_max_pixels, + dropout=args.tag_dropout, + shuffle=not args.no_tag_shuffle, + template_key=args.train_data_template, + valid_set_size=args.valid_set_size, + valid_set_repeat=args.valid_set_repeat, + seed=args.seed, + filter=keyword_filter, + dtype=weight_dtype + ) + datamodule.setup() + + train_dataloader = datamodule.train_dataloader + val_dataloader = datamodule.val_dataloader + + if args.num_class_images != 0: + generate_class_images( + accelerator, + text_encoder, + vae, + unet, + tokenizer, + sample_scheduler, + datamodule.data_train, + args.sample_batch_size, + args.sample_image_size, + args.sample_steps + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps_per_epoch=len(train_dataloader), + gradient_accumulation_steps=args.gradient_accumulation_steps, + min_lr=args.lr_min_lr, + warmup_func=args.lr_warmup_func, + annealing_func=args.lr_annealing_func, + warmup_exp=args.lr_warmup_exp, + annealing_exp=args.lr_annealing_exp, + cycles=args.lr_cycles, + train_epochs=args.num_train_epochs, + warmup_epochs=args.lr_warmup_epochs, + ) + + trainer = Trainer( + accelerator=accelerator, + unet=unet, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + noise_scheduler=noise_scheduler, + sample_scheduler=sample_scheduler, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + dtype=weight_dtype, + ) + + trainer( + strategy_class=TextualInversionTrainingStrategy, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + num_train_epochs=args.num_train_epochs, + sample_frequency=args.sample_frequency, + checkpoint_frequency=args.checkpoint_frequency, + global_step_offset=global_step_offset, + prior_loss_weight=args.prior_loss_weight, + output_dir=output_dir, + placeholder_tokens=args.placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + learning_rate=args.learning_rate, + sample_steps=args.sample_steps, + sample_image_size=args.sample_image_size, + sample_batch_size=args.sample_batch_size, + sample_batches=args.sample_batches, + seed=args.seed, + ) + + +if __name__ == "__main__": + main() -- cgit v1.2.3-54-g00ecf