import argparse

import torch
import torch.utils.checkpoint
from accelerate.logging import get_logger

from util import load_config
from data.csv import VlpnDataItem
from training.common import train_setup
from training.modules.ti import train_ti
from training.util import save_args

logger = get_logger(__name__)

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path, if not the same as model_name.",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file containing the training data.",
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        nargs='*',
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        nargs='*',
        help="A token to use as initializer word.",
    )
    parser.add_argument(
        "--num_vectors",
        type=int,
        nargs='*',
        help="Number of vectors per embedding.",
    )
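    # Illustrative pairing of the three token arguments (values below are
    # hypothetical, not recommendations):
    #   --placeholder_token "<obj>" "<style>" --initializer_token object painting --num_vectors 4 2
    # Each placeholder token is paired positionally with one initializer token;
    # num_vectors may be omitted (defaults to 1 per token) or given per placeholder.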
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=1,
        help="How many class images to generate.",
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/text-inversion",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=0,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        help="Don't shuffle tags.",
    )
    parser.add_argument(
        "--vector_dropout",
        type=float,
        default=0,
        help="Vector dropout probability.",
    )
    parser.add_argument(
        "--vector_shuffle",
        type=str,
        default="auto",
        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading."
            " 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100,
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]',
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function',
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]',
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=1,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler.",
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use an EMA model.",
    )
    parser.add_argument(
        "--ema_inv_gamma",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--ema_power",
        type=float,
        default=4 / 5,
    )
    parser.add_argument(
        "--ema_max_decay",
        type=float,
        default=0.9999,
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=0,
        help="Weight decay to use.",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer.",
    )
    # NOTE: argparse's type=bool treats any non-empty string as True, so this is
    # a flag rather than a bool-typed option.
    parser.add_argument(
        "--adam_amsgrad",
        action="store_true",
        default=False,
        help="Whether to use AMSGrad with the Adam optimizer.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=5,
        help="How often to save a checkpoint and sample image (in epochs)",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to generate sample images (in epochs)",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset.",
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss.",
    )
    parser.add_argument(
        "--emb_decay_target",
        default=0.4,
        type=float,
        help="Embedding decay target.",
    )
    parser.add_argument(
        "--emb_decay_factor",
        default=1,
        type=float,
        help="Embedding decay factor.",
    )
    parser.add_argument(
        "--emb_decay_start",
        default=1e-4,
        type=float,
        help="Embedding decay start offset.",
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (e.g., logs/token_name/2022-09-22T23-36-27)",
    )
    parser.add_argument(
        "--global_step",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script.",
    )
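    # A hypothetical --config file (keys mirror the CLI flag names; all values
    # shown are illustrative, not recommendations):
    #   {
    #       "pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
    #       "train_data_file": "data/train.csv",
    #       "project": "my-project",
    #       "placeholder_token": ["<obj>"],
    #       "initializer_token": ["object"],
    #       "num_vectors": 4
    #   }
    # The file is re-parsed through argparse below, so flags given on the
    # command line still override values from the file.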
    args = parser.parse_args()

    if args.config is not None:
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    # Normalize the token arguments: nargs='*' yields None when a flag is
    # absent, and a config file may supply a bare string instead of a list.
    if args.placeholder_token is None:
        args.placeholder_token = []

    if args.initializer_token is None:
        args.initializer_token = []

    if isinstance(args.placeholder_token, str):
        args.placeholder_token = [args.placeholder_token]

    if len(args.placeholder_token) == 0:
        # Derive one placeholder per initializer token if none were given.
        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]

    if isinstance(args.initializer_token, str):
        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)

    if len(args.initializer_token) == 0:
        raise ValueError("You must specify --initializer_token")

    if len(args.placeholder_token) != len(args.initializer_token):
        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")

    if args.num_vectors is None:
        args.num_vectors = 1

    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.initializer_token)

    if len(args.placeholder_token) != len(args.num_vectors):
        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def main():
    args = parse_args()

    def data_filter(item: VlpnDataItem):
        # Keep only items whose prompt mentions one of the placeholder tokens ...
        cond1 = any(
            keyword in part
            for keyword in args.placeholder_token
            for part in item.prompt
        )
        # ... that belong to one of the requested collections ...
        cond2 = args.collection is None or any(
            collection in item.collection
            for collection in args.collection
        )
        # ... and that are not in an excluded collection.
        cond3 = args.exclude_collections is None or not any(
            collection in item.collection
            for collection in args.exclude_collections
        )
        return cond1 and cond2 and cond3

    setup = train_setup(
        output_dir=args.output_dir,
        project=args.project,
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        learning_rate=args.learning_rate,
        data_file=args.train_data_file,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        seed=args.seed,
        vector_shuffle=args.vector_shuffle,
        vector_dropout=args.vector_dropout,
        gradient_checkpointing=args.gradient_checkpointing,
        embeddings_dir=args.embeddings_dir,
        placeholder_token=args.placeholder_token,
        initializer_token=args.initializer_token,
        num_vectors=args.num_vectors,
        scale_lr=args.scale_lr,
        use_8bit_adam=args.use_8bit_adam,
        train_batch_size=args.train_batch_size,
        class_image_dir=args.class_image_dir,
        num_class_images=args.num_class_images,
        resolution=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        tag_dropout=args.tag_dropout,
        tag_shuffle=not args.no_tag_shuffle,
        data_template=args.train_data_template,
        valid_set_size=args.valid_set_size,
        valid_set_repeat=args.valid_set_repeat,
        data_filter=data_filter,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_steps=args.sample_steps,
    )

    save_args(setup.output_dir, args)

    train_ti(
        setup=setup,
        num_train_epochs=args.num_train_epochs,
        num_class_images=args.num_class_images,
        prior_loss_weight=args.prior_loss_weight,
        use_ema=args.use_ema,
        ema_inv_gamma=args.ema_inv_gamma,
        ema_power=args.ema_power,
        ema_max_decay=args.ema_max_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        adam_weight_decay=args.adam_weight_decay,
        adam_epsilon=args.adam_epsilon,
        adam_amsgrad=args.adam_amsgrad,
        lr_scheduler=args.lr_scheduler,
        lr_min_lr=args.lr_min_lr,
        lr_warmup_func=args.lr_warmup_func,
        lr_annealing_func=args.lr_annealing_func,
        lr_warmup_exp=args.lr_warmup_exp,
        lr_annealing_exp=args.lr_annealing_exp,
        lr_cycles=args.lr_cycles,
        lr_warmup_epochs=args.lr_warmup_epochs,
        emb_decay_target=args.emb_decay_target,
        emb_decay_factor=args.emb_decay_factor,
        emb_decay_start=args.emb_decay_start,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        sample_frequency=args.sample_frequency,
        sample_steps=args.sample_steps,
        checkpoint_frequency=args.checkpoint_frequency,
        global_step_offset=args.global_step,
    )


if __name__ == "__main__":
    main()
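# Example invocation (script name, paths, tokens, and the model identifier are
# illustrative placeholders, not tested settings):
#
#   python train_ti.py \
#       --pretrained_model_name_or_path stabilityai/stable-diffusion-2-1 \
#       --train_data_file data/train.csv \
#       --project my-project \
#       --placeholder_token "<obj>" \
#       --initializer_token object \
#       --num_vectors 4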