import argparse
import datetime
import logging
from functools import partial
from pathlib import Path
from typing import Union
import math
import warnings

import torch
import torch.utils.checkpoint
import hidet

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers.models.attention_processor import AttnProcessor
from timm.models import create_model
import transformers

import numpy as np

from slugify import slugify

from data.csv import VlpnDataModule, keyword_filter
from models.clip.embeddings import patch_managed_embeddings
from training.functional import train, add_placeholder_tokens, get_models
from training.strategy.ti import textual_inversion_strategy
from training.optimization import get_scheduler
from training.sampler import create_named_schedule_sampler
from training.util import AverageMeter, save_args
from util.files import load_config, load_embeddings_from_dir

logger = get_logger(__name__)

warnings.filterwarnings("ignore")


torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# torch._dynamo.config.log_level = logging.WARNING

hidet.torch.dynamo_config.use_tensor_core(True)
hidet.torch.dynamo_config.use_attention(True)
hidet.torch.dynamo_config.search_space(0)


def _str_to_bool(value: Union[str, bool]) -> bool:
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` is a pitfall: ``bool("False")`` is ``True``,
    so any non-empty string would enable the option. This converter accepts the
    usual spellings of true/false and rejects anything else.

    Args:
        value: The raw CLI string (or an already-parsed bool, e.g. from a
            JSON config merged into the namespace).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognizable boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def parse_args():
    """Parse and validate the command line (optionally merged with a JSON config).

    When ``--config`` is given, its contents are loaded with ``load_config`` and
    used as the initial namespace, and the command line is parsed again on top.

    Returns:
        argparse.Namespace: The validated arguments. The token-related options are
        normalized to lists of matching lengths.

    Raises:
        ValueError: If a required option is missing or list-valued options have
            inconsistent lengths.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file containing the training data.",
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        nargs="*",
        default="template",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
    )
    parser.add_argument(
        "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
    )
    parser.add_argument(
        "--placeholder_tokens",
        type=str,
        nargs="*",
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_tokens",
        type=str,
        nargs="*",
        help="A token to use as initializer word.",
    )
    parser.add_argument(
        "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
    )
    parser.add_argument(
        "--initializer_noise",
        type=float,
        default=0,
        help="Noise to apply to the initializer word",
    )
    parser.add_argument(
        "--alias_tokens",
        type=str,
        nargs="*",
        default=[],
        help="Tokens to create an alias for.",
    )
    parser.add_argument(
        "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
    )
    parser.add_argument(
        "--sequential",
        action="store_true",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=0,
        help="How many class images to generate.",
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs="*",
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/text-inversion",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--train_dir_embeddings",
        action="store_true",
        help="Train embeddings loaded from embeddings directory.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs="*",
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=0,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        help="Disable tag shuffling.",
    )
    parser.add_argument(
        "--vector_dropout",
        # A probability: parsing as int made any fractional value unusable.
        type=float,
        default=0,
        help="Vector dropout probability.",
    )
    parser.add_argument(
        "--vector_shuffle",
        type=str,
        default="auto",
        choices=["all", "trailing", "leading", "between", "auto", "off"],
        help="Vector shuffling algorithm.",
    )
    parser.add_argument(
        "--input_pertubation",
        type=float,
        default=0,
        help="The scale of input perturbation. Recommended 0.1.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=None)
    parser.add_argument("--num_train_steps", type=int, default=2000)
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        choices=[
            "linear",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
            "constant",
            "constant_with_warmup",
            "one_cycle",
        ],
        help="The scheduler type to use.",
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        choices=["linear", "cos"],
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function',
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        choices=["linear", "half_cos", "cos"],
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=1,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler.",
    )
    parser.add_argument(
        "--use_ema", action="store_true", help="Whether to use EMA model."
    )
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
    parser.add_argument("--ema_power", type=float, default=4 / 5)
    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
    parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
    parser.add_argument(
        "--schedule_sampler",
        type=str,
        default="uniform",
        choices=["uniform", "loss-second-moment"],
        help="Noise schedule sampler.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adan",
        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
        help="Optimizer to use",
    )
    parser.add_argument(
        "--dadaptation_d0",
        type=float,
        default=1e-6,
        help="The d0 parameter for Dadaptation optimizers.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=None,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=None,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--adam_amsgrad",
        # type=bool would turn "--adam_amsgrad False" into True.
        type=_str_to_bool,
        default=False,
        help="Amsgrad value for the Adam optimizer",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--compile_unet",
        action="store_true",
        help="Compile UNet with Torch Dynamo.",
    )
    parser.add_argument(
        "--use_xformers",
        action="store_true",
        help="Use xformers.",
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=999999,
        help="How often to save a checkpoint and sample image (in epochs)",
    )
    parser.add_argument(
        "--no_milestone_checkpoints",
        action="store_true",
        help="Disable milestone checkpoints (checkpoints saved on maximum accuracy).",
    )
    parser.add_argument(
        "--sample_num",
        type=int,
        default=None,
        help="How often to save a checkpoint and sample image (in number of samples)",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to save a checkpoint and sample image (in epochs)",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset.",
    )
    parser.add_argument(
        "--train_set_pad",
        type=int,
        default=None,
        help="The number to fill train dataset items up to.",
    )
    parser.add_argument(
        "--valid_set_pad",
        type=int,
        default=None,
        help="The number to fill validation dataset items up to.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=10,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss.",
    )
    parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
    parser.add_argument(
        "--emb_dropout",
        type=float,
        default=0,
        help="Embedding dropout probability.",
    )
    parser.add_argument(
        "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
    )
    parser.add_argument(
        "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
    )
    parser.add_argument(
        "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)",
    )
    parser.add_argument(
        "--global_step",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script.",
    )

    args = parser.parse_args()
    if args.config is not None:
        # Config values seed the namespace; explicit CLI flags override them.
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    # nargs="*" options are None when omitted. Normalize them so that either of
    # --placeholder_tokens / --initializer_tokens can be derived from the other
    # below (previously an omitted one crashed on len(None)).
    if args.placeholder_tokens is None:
        args.placeholder_tokens = []

    if args.initializer_tokens is None:
        args.initializer_tokens = []

    if isinstance(args.placeholder_tokens, str):
        args.placeholder_tokens = [args.placeholder_tokens]

    if isinstance(args.initializer_tokens, str):
        args.initializer_tokens = [args.initializer_tokens] * len(
            args.placeholder_tokens
        )

    # No placeholders given: generate "<*0>", "<*1>", ... one per initializer.
    if len(args.placeholder_tokens) == 0:
        args.placeholder_tokens = [
            f"<*{i}>" for i in range(len(args.initializer_tokens))
        ]

    # No initializers given: initialize each placeholder from itself.
    if len(args.initializer_tokens) == 0:
        args.initializer_tokens = args.placeholder_tokens.copy()

    if len(args.placeholder_tokens) != len(args.initializer_tokens):
        raise ValueError(
            "--placeholder_tokens and --initializer_tokens must have the same number of items"
        )

    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)

    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
        args.num_vectors
    ):
        raise ValueError(
            "--placeholder_tokens and --num_vectors must have the same number of items"
        )

    if args.alias_tokens is None:
        args.alias_tokens = []

    # Aliases are given as a flat list of (alias, target) pairs.
    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
        raise ValueError("--alias_tokens must be a list with an even number of items")

    if args.filter_tokens is None:
        args.filter_tokens = args.placeholder_tokens.copy()

    if isinstance(args.filter_tokens, str):
        args.filter_tokens = [args.filter_tokens]

    if args.sequential:
        # Build a new list rather than "+=" so argparse's shared default=[] list
        # is never mutated in place.
        args.alias_tokens = args.alias_tokens + [
            item
            for pair in zip(args.placeholder_tokens, args.initializer_tokens)
            for item in pair
        ]

        if isinstance(args.train_data_template, str):
            args.train_data_template = [args.train_data_template] * len(
                args.placeholder_tokens
            )

        if len(args.placeholder_tokens) != len(args.train_data_template):
            raise ValueError(
                "--placeholder_tokens and --train_data_template must have the same number of items"
            )

        if args.num_vectors is None:
            args.num_vectors = [None] * len(args.placeholder_tokens)
    else:
        # With nargs="*" a single CLI value arrives as a one-element list;
        # unwrap it instead of rejecting it.
        if (
            isinstance(args.train_data_template, list)
            and len(args.train_data_template) == 1
        ):
            args.train_data_template = args.train_data_template[0]

        if isinstance(args.train_data_template, list):
            raise ValueError(
                "--train_data_template can't be a list in simultaneous mode"
            )

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    # Optimizer-specific beta defaults (Lion uses different recommended betas).
    if args.adam_beta1 is None:
        if args.optimizer == "lion":
            args.adam_beta1 = 0.95
        else:
            args.adam_beta1 = 0.9

    if args.adam_beta2 is None:
        if args.optimizer == "lion":
            args.adam_beta2 = 0.98
        else:
            args.adam_beta2 = 0.999

    return args


def main():
    """Run textual-inversion training as configured by :func:`parse_args`.

    This function:

    - creates a timestamped output directory under ``output_dir/<project>/``;
    - loads the models and patches the text encoder's embeddings;
    - registers alias tokens and, optionally, embeddings from ``--embeddings_dir``;
    - builds the optimizer factory;
    - trains the placeholder tokens, either all at once or one token at a time
      with ``--sequential``.
    """
    args = parse_args()

    global_step_offset = args.global_step

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir) / slugify(args.project) / now
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        project_dir=f"{output_dir}",
        mixed_precision=args.mixed_precision,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)

    if args.seed is None:
        # Draw a random seed, then record it in args so save_args persists it.
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
        args.pretrained_model_name_or_path
    )
    embeddings = patch_managed_embeddings(
        text_encoder, args.emb_alpha, args.emb_dropout
    )
    schedule_sampler = create_named_schedule_sampler(
        args.schedule_sampler, noise_scheduler.config.num_train_timesteps
    )

    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
    tokenizer.set_dropout(args.vector_dropout)

    vae.enable_slicing()

    if args.use_xformers:
        vae.set_use_memory_efficient_attention_xformers(True)
        unet.enable_xformers_memory_efficient_attention()
    elif args.compile_unet:
        unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False

        # Replace every attention processor in the UNet with the plain
        # AttnProcessor before compilation.
        proc = AttnProcessor()

        def fn_recursive_set_proc(module: torch.nn.Module):
            if hasattr(module, "processor"):
                module.processor = proc

            for child in module.children():
                fn_recursive_set_proc(child)

        fn_recursive_set_proc(unet)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if len(args.alias_tokens) != 0:
        # alias_tokens is a flat [alias, target, alias, target, ...] list.
        alias_placeholder_tokens = args.alias_tokens[::2]
        alias_initializer_tokens = args.alias_tokens[1::2]

        added_tokens, added_ids = add_placeholder_tokens(
            tokenizer=tokenizer,
            embeddings=embeddings,
            placeholder_tokens=alias_placeholder_tokens,
            initializer_tokens=alias_initializer_tokens,
        )
        embeddings.persist()
        print(
            f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
        )

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(
            tokenizer, embeddings, embeddings_dir
        )
        print(
            f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
        )

        if args.train_dir_embeddings:
            # NOTE(review): only placeholder_tokens is replaced here;
            # initializer_tokens / num_vectors are not updated to match — confirm
            # this path is consistent with how run() consumes them.
            args.placeholder_tokens = added_tokens
            print("Training embeddings from embeddings dir")
        else:
            embeddings.persist()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    if args.find_lr:
        args.learning_rate = 1e-5
        args.lr_scheduler = "exponential_growth"

    # create_optimizer(params, lr=...) is a partial with every option except the
    # parameters and learning rate bound.
    if args.optimizer == "adam8bit":
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        create_optimizer = partial(
            bnb.optim.AdamW8bit,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
    elif args.optimizer == "adam":
        create_optimizer = partial(
            torch.optim.AdamW,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
    elif args.optimizer == "adan":
        try:
            import timm.optim
        except ImportError:
            raise ImportError(
                "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
            )

        create_optimizer = partial(
            timm.optim.Adan,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            no_prox=True,
        )
    elif args.optimizer == "lion":
        try:
            import lion_pytorch
        except ImportError:
            raise ImportError(
                "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
            )

        create_optimizer = partial(
            lion_pytorch.Lion,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            use_triton=True,
        )
    elif args.optimizer == "adafactor":
        create_optimizer = partial(
            transformers.optimization.Adafactor,
            weight_decay=args.adam_weight_decay,
            scale_parameter=True,
            relative_step=True,
            warmup_init=True,
        )

        # Adafactor with relative_step computes its own step size, so no
        # explicit learning rate is passed.
        args.lr_scheduler = "adafactor"
        args.lr_min_lr = args.learning_rate
        args.learning_rate = None
    elif args.optimizer == "dadam":
        try:
            import dadaptation
        except ImportError:
            raise ImportError(
                "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
            )

        create_optimizer = partial(
            dadaptation.DAdaptAdam,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=True,
            d0=args.dadaptation_d0,
        )
    elif args.optimizer == "dadan":
        try:
            import dadaptation
        except ImportError:
            raise ImportError(
                "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
            )

        create_optimizer = partial(
            dadaptation.DAdaptAdan,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            d0=args.dadaptation_d0,
        )
    else:
        raise ValueError(f'Unknown --optimizer "{args.optimizer}"')

    trainer = partial(
        train,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        noise_scheduler=noise_scheduler,
        schedule_sampler=schedule_sampler,
        min_snr_gamma=args.min_snr_gamma,
        dtype=weight_dtype,
        seed=args.seed,
        compile_unet=args.compile_unet,
        prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
        no_val=args.valid_set_size == 0,
        strategy=textual_inversion_strategy,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        checkpoint_frequency=args.checkpoint_frequency,
        milestone_checkpoints=not args.no_milestone_checkpoints,
        global_step_offset=global_step_offset,
        input_pertubation=args.input_pertubation,
        # --
        use_emb_decay=args.use_emb_decay,
        emb_decay_target=args.emb_decay_target,
        emb_decay=args.emb_decay,
        use_ema=args.use_ema,
        ema_inv_gamma=args.ema_inv_gamma,
        ema_power=args.ema_power,
        ema_max_decay=args.ema_max_decay,
        sample_scheduler=sample_scheduler,
        sample_batch_size=args.sample_batch_size,
        sample_num_batches=args.sample_batches,
        sample_num_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
    )

    # Only the token embedding table is optimized.
    optimizer = create_optimizer(
        text_encoder.text_model.embeddings.token_embedding.parameters(),
        lr=args.learning_rate,
    )

    data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
    data_npgenerator = np.random.default_rng(args.seed)

    def run(
        i: int,
        placeholder_tokens: list[str],
        initializer_tokens: list[str],
        num_vectors: Union[int, list[int]],
        data_template: str,
    ):
        """Train one batch of placeholder tokens through one or more cycles.

        Cycles come first from ``--auto_cycles``. Once that list is used up,
        they are read interactively from stdin:

        - ``o``: one_cycle
        - ``w``: warmup
        - ``c``: constant
        - ``d``: decay
        - ``s``: stop

        Any other input is ignored and the prompt is shown again.

        Args:
            i: Zero-based index of this batch (used for display only).
            placeholder_tokens: Tokens to add and train.
            initializer_tokens: Initializer token for each placeholder.
            num_vectors: Vectors per embedding, forwarded to
                ``add_placeholder_tokens``.
            data_template: Template key forwarded to ``VlpnDataModule``.
        """
        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
            tokenizer=tokenizer,
            embeddings=embeddings,
            placeholder_tokens=placeholder_tokens,
            initializer_tokens=initializer_tokens,
            num_vectors=num_vectors,
            initializer_noise=args.initializer_noise,
        )

        stats = list(
            zip(
                placeholder_tokens,
                placeholder_token_ids,
                initializer_tokens,
                initializer_token_ids,
            )
        )

        print("")
        print(f"============ TI batch {i + 1} ============")
        print("")
        print(stats)

        # Restrict the dataset filter to the tokens trained in this batch.
        filter_tokens = [
            token for token in args.filter_tokens if token in placeholder_tokens
        ]

        datamodule = VlpnDataModule(
            data_file=args.train_data_file,
            batch_size=args.train_batch_size,
            tokenizer=tokenizer,
            class_subdir=args.class_image_dir,
            num_class_images=args.num_class_images,
            size=args.resolution,
            num_buckets=args.num_buckets,
            progressive_buckets=args.progressive_buckets,
            bucket_step_size=args.bucket_step_size,
            bucket_max_pixels=args.bucket_max_pixels,
            tag_dropout=args.tag_dropout,
            shuffle=not args.no_tag_shuffle,
            template_key=data_template,
            placeholder_tokens=args.placeholder_tokens,
            valid_set_size=args.valid_set_size,
            train_set_pad=args.train_set_pad,
            valid_set_pad=args.valid_set_pad,
            filter=partial(
                keyword_filter, filter_tokens, args.collection, args.exclude_collections
            ),
            dtype=weight_dtype,
            generator=data_generator,
            npgenerator=data_npgenerator,
        )
        datamodule.setup()

        num_train_epochs = args.num_train_epochs
        sample_frequency = args.sample_frequency
        if num_train_epochs is None:
            # Derive epochs from --num_train_steps; sample_frequency is then
            # rescaled from steps to epochs.
            num_train_epochs = (
                math.ceil(args.num_train_steps / len(datamodule.train_dataset))
                * args.gradient_accumulation_steps
            )
            sample_frequency = math.ceil(
                num_train_epochs * (sample_frequency / args.num_train_steps)
            )
        num_training_steps_per_epoch = math.ceil(
            len(datamodule.train_dataset) / args.gradient_accumulation_steps
        )
        num_train_steps = num_training_steps_per_epoch * num_train_epochs
        if args.sample_num is not None:
            sample_frequency = math.ceil(num_train_epochs / args.sample_num)

        project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"

        if accelerator.is_main_process:
            accelerator.init_trackers(project)

        sample_output_dir = output_dir / project / "samples"

        training_iter = 0
        auto_cycles = list(args.auto_cycles)
        learning_rate = args.learning_rate
        # NOTE(review): this initial value (including the "exponential_growth"
        # / "adafactor" overrides set above for --find_lr / adafactor) is always
        # replaced by a cycle choice below before use, so --lr_scheduler has no
        # effect here — confirm this is intended.
        scheduler_name = args.lr_scheduler
        lr_warmup_epochs = args.lr_warmup_epochs
        lr_cycles = args.lr_cycles

        avg_loss = AverageMeter()
        avg_acc = AverageMeter()
        avg_loss_val = AverageMeter()
        avg_acc_val = AverageMeter()

        while True:
            if len(auto_cycles) != 0:
                response = auto_cycles.pop(0)
            else:
                response = input(
                    "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
                )

            choice = response.lower().strip()

            if choice == "o":
                # NOTE(review): this resets the rate to 2x the base rate, so
                # --cycle_decay does not carry over into one_cycle cycles.
                if args.learning_rate is not None:
                    learning_rate = args.learning_rate * 2
                else:
                    learning_rate = args.learning_rate

                scheduler_name = "one_cycle"
                lr_warmup_epochs = args.lr_warmup_epochs
                lr_cycles = args.lr_cycles
            elif choice == "w":
                scheduler_name = "constant_with_warmup"
                lr_warmup_epochs = num_train_epochs
            elif choice == "c":
                scheduler_name = "constant"
            elif choice == "d":
                scheduler_name = "cosine"
                lr_warmup_epochs = 0
                lr_cycles = 1
            elif choice == "s":
                break
            else:
                continue

            print("")
            print(f"------------ TI cycle {training_iter + 1}: {response} ------------")
            print("")

            # Only the first param group (the embeddings) gets the cycle's rate.
            for group, lr in zip(optimizer.param_groups, [learning_rate]):
                group["lr"] = lr

            lr_scheduler = get_scheduler(
                scheduler_name,
                optimizer=optimizer,
                num_training_steps_per_epoch=len(datamodule.train_dataloader),
                gradient_accumulation_steps=args.gradient_accumulation_steps,
                min_lr=args.lr_min_lr,
                warmup_func=args.lr_warmup_func,
                annealing_func=args.lr_annealing_func,
                warmup_exp=args.lr_warmup_exp,
                annealing_exp=args.lr_annealing_exp,
                cycles=lr_cycles,
                end_lr=1e3,
                train_epochs=num_train_epochs,
                warmup_epochs=lr_warmup_epochs,
                mid_point=args.lr_mid_point,
            )

            checkpoint_output_dir = (
                output_dir / project / f"checkpoints_{training_iter}"
            )

            trainer(
                train_dataloader=datamodule.train_dataloader,
                val_dataloader=datamodule.val_dataloader,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
                num_train_epochs=num_train_epochs,
                global_step_offset=training_iter * num_train_steps,
                cycle=training_iter,
                # --
                group_labels=["emb"],
                checkpoint_output_dir=checkpoint_output_dir,
                sample_output_dir=sample_output_dir,
                sample_frequency=sample_frequency,
                placeholder_tokens=placeholder_tokens,
                placeholder_token_ids=placeholder_token_ids,
                avg_loss=avg_loss,
                avg_acc=avg_acc,
                avg_loss_val=avg_loss_val,
                avg_acc_val=avg_acc_val,
            )

            training_iter += 1
            if learning_rate is not None:
                learning_rate *= args.cycle_decay

        accelerator.end_training()

    if not args.sequential:
        run(
            0,
            args.placeholder_tokens,
            args.initializer_tokens,
            args.num_vectors,
            args.train_data_template,
        )
    else:
        # Train one token at a time, persisting each finished embedding before
        # moving on to the next token.
        for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
            range(len(args.placeholder_tokens)),
            args.placeholder_tokens,
            args.initializer_tokens,
            args.num_vectors,
            args.train_data_template,
        ):
            run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
            embeddings.persist()


if __name__ == "__main__":
    main()