import argparse
import datetime
import logging
import itertools
from pathlib import Path
from functools import partial
import math
import warnings

import torch
import torch._dynamo
import torch.utils.checkpoint
import hidet

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed

# from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.import_utils import is_xformers_available
import transformers

import numpy as np
from slugify import slugify

from data.csv import VlpnDataModule, keyword_filter
from models.clip.embeddings import patch_managed_embeddings
from training.functional import train, add_placeholder_tokens, get_models
from training.strategy.dreambooth import dreambooth_strategy
from training.optimization import get_scheduler
from training.sampler import create_named_schedule_sampler
from training.util import AverageMeter, save_args
from util.files import load_config, load_embeddings_from_dir

logger = get_logger(__name__)

warnings.filterwarnings("ignore")

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# torch._dynamo.config.log_level = logging.WARNING
torch._dynamo.config.suppress_errors = True

hidet.torch.dynamo_config.use_tensor_core(True)
hidet.torch.dynamo_config.search_space(0)


def patch_xformers(dtype):
    """Force xformers memory-efficient attention to run in ``dtype``.

    Replaces ``xformers.ops.memory_efficient_attention`` with a wrapper that
    casts ``query``, ``key`` and ``value`` to ``dtype`` before delegating to
    the original implementation. Does nothing when xformers is not available.

    Args:
        dtype: The torch dtype the attention inputs are cast to.
    """
    if not is_xformers_available():
        return

    import xformers
    import xformers.ops

    # Keep a handle on the unpatched function; the wrapper delegates to it.
    original_attention = xformers.ops.memory_efficient_attention

    def xformers_memory_efficient_attention(
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
    ):
        return original_attention(
            query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
        )

    xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention


def _str_to_bool(value):
    """Parse a command line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value

    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off", ""):
        return False

    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Create the argument parser holding every option of this script."""
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A folder containing the training data.",
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--train_set_pad",
        type=int,
        default=None,
        help="The number to fill train dataset items up to.",
    )
    parser.add_argument(
        "--valid_set_pad",
        type=int,
        default=None,
        help="The number to fill validation dataset items up to.",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
    )
    parser.add_argument(
        "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
    )
    parser.add_argument(
        "--placeholder_tokens",
        type=str,
        nargs="*",
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_tokens",
        type=str,
        nargs="*",
        help="A token to use as initializer word.",
    )
    parser.add_argument(
        "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
    )
    parser.add_argument(
        "--initializer_noise",
        type=float,
        default=0,
        help="Noise to apply to the initializer word",
    )
    parser.add_argument(
        "--alias_tokens",
        type=str,
        nargs="*",
        default=[],
        help="Tokens to create an alias for.",
    )
    parser.add_argument(
        "--inverted_initializer_tokens",
        type=str,
        nargs="*",
        help="A token to use as initializer word.",
    )
    parser.add_argument(
        "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs="*",
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=2,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--prompt_dropout",
        type=float,
        default=0,
        help="Prompt dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        # The flag disables shuffling (main passes `shuffle=not no_tag_shuffle`).
        help="Do not shuffle tags.",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=0,
        help="How many class images to generate.",
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/dreambooth",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs="*",
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--input_pertubation",
        type=float,
        default=0,
        help="The scale of input pretubation. Recommended 0.1.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=None)
    parser.add_argument("--num_train_steps", type=int, default=2000)
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    # main() reads args.gradient_checkpointing, so the option has to exist even
    # when no config file provides it.
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Enable gradient checkpointing for the UNet and the text encoder.",
    )
    parser.add_argument(
        "--train_text_encoder_cycles",
        type=int,
        default=999999,
        help="Number of epochs the text encoder will be trained.",
    )
    parser.add_argument(
        "--text_encoder_unfreeze_last_n_layers",
        type=int,
        default=-1,
        help="Number of text encoder layers to train.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate_unet",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--learning_rate_text",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        choices=[
            "linear",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
            "constant",
            "constant_with_warmup",
            "one_cycle",
        ],
        help="The scheduler type to use.",
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler (if supported).",
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        choices=["linear", "cos"],
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function',
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        choices=["linear", "half_cos", "cos"],
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=3,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler.",
    )
    parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
    parser.add_argument(
        "--schedule_sampler",
        type=str,
        default="uniform",
        choices=["uniform", "loss-second-moment"],
        help="Noise schedule sampler.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adan",
        choices=[
            "adam",
            "adam8bit",
            "adan",
            "lion",
            "dadam",
            "dadan",
            "dlion",
            "adafactor",
        ],
        help="Optimizer to use",
    )
    parser.add_argument(
        "--dadaptation_d0",
        type=float,
        default=1e-6,
        help="The d0 parameter for Dadaptation optimizers.",
    )
    parser.add_argument(
        "--dadaptation_growth_rate",
        type=float,
        default=math.inf,
        help="The growth_rate parameter for Dadaptation optimizers.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=None,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=None,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--adam_amsgrad",
        # Not `type=bool`: that would turn "--adam_amsgrad False" into True.
        type=_str_to_bool,
        default=False,
        help="Amsgrad value for the Adam optimizer",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose "
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--compile_unet",
        action="store_true",
        help="Compile UNet with Torch Dynamo.",
    )
    parser.add_argument(
        "--use_xformers",
        action="store_true",
        help="Use xformers.",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to save a checkpoint and sample image",
    )
    parser.add_argument(
        "--sample_num",
        type=int,
        default=None,
        help="How often to save a checkpoint and sample image (in number of samples)",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset.",
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=10,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss.",
    )
    parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
    parser.add_argument(
        "--emb_dropout",
        type=float,
        default=0,
        help="Embedding dropout probability.",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script.",
    )
    return parser


def parse_args():
    """Parse, validate and normalise the command line arguments.

    When ``--config`` is given, the loaded configuration seeds the namespace
    and the command line is parsed a second time on top of it.

    Returns:
        The argument namespace. Token related options are normalised to lists
        and the Adam/Lion betas are filled with per-optimizer defaults when
        they were not given.

    Raises:
        ValueError: If a required option is missing or the token lists have
            inconsistent lengths.
    """
    parser = _build_arg_parser()

    args = parser.parse_args()
    if args.config is not None:
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    if args.initializer_tokens is None:
        args.initializer_tokens = []

    if args.placeholder_tokens is None:
        args.placeholder_tokens = []

    # Scalars can only come from a config file; argparse itself yields lists.
    if isinstance(args.placeholder_tokens, str):
        args.placeholder_tokens = [args.placeholder_tokens]

    if isinstance(args.initializer_tokens, str):
        args.initializer_tokens = [args.initializer_tokens] * len(
            args.placeholder_tokens
        )

    if len(args.placeholder_tokens) == 0:
        args.placeholder_tokens = [
            f"<*{i}>" for i in range(len(args.initializer_tokens))
        ]

    if len(args.initializer_tokens) == 0:
        args.initializer_tokens = args.placeholder_tokens.copy()

    if len(args.placeholder_tokens) != len(args.initializer_tokens):
        raise ValueError(
            "--placeholder_tokens and --initializer_tokens must have the same number of items"
        )

    if isinstance(args.inverted_initializer_tokens, str):
        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
            args.placeholder_tokens
        )

    if (
        isinstance(args.inverted_initializer_tokens, list)
        and len(args.inverted_initializer_tokens) != 0
    ):
        # Every placeholder gets an "inv_" twin initialised from the inverted
        # initializer tokens.
        args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
        args.initializer_tokens += args.inverted_initializer_tokens

    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)

    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
        args.num_vectors
    ):
        raise ValueError(
            "--placeholder_tokens and --num_vectors must have the same number of items"
        )

    if args.alias_tokens is None:
        args.alias_tokens = []

    if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
        raise ValueError("--alias_tokens must be a list with an even number of items")

    if args.filter_tokens is None:
        args.filter_tokens = args.placeholder_tokens.copy()

    if isinstance(args.filter_tokens, str):
        args.filter_tokens = [args.filter_tokens]

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    if args.adam_beta1 is None:
        if args.optimizer in ("adam", "adam8bit", "dadam"):
            args.adam_beta1 = 0.9
        elif args.optimizer in ("lion", "dlion"):
            args.adam_beta1 = 0.95

    if args.adam_beta2 is None:
        if args.optimizer in ("adam", "adam8bit", "dadam"):
            args.adam_beta2 = 0.999
        elif args.optimizer in ("lion", "dlion"):
            args.adam_beta2 = 0.98

    return args


def _cycle_learning_rate(base_lr, cycle_decay, cycle, boost):
    """Return the learning rate for training cycle number ``cycle``.

    ``base_lr`` is decayed by ``cycle_decay ** cycle`` and additionally doubled
    when ``boost`` is set. ``None`` (the optimizer manages its own learning
    rate) is passed through unchanged instead of raising ``TypeError``.
    """
    if base_lr is None:
        return None

    if boost:
        return base_lr * 2 * (cycle_decay**cycle)

    return base_lr * (cycle_decay**cycle)


def main():
    """Run Dreambooth training as a sequence of training cycles.

    Sets up the output directory, accelerator, models, embeddings, optimizer
    factory and data module, then repeatedly runs the trainer. The action for
    each cycle is taken from ``--auto_cycles`` first and read interactively
    from stdin once those are used up, until the stop action is chosen.
    """
    args = parse_args()

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir) / slugify(args.project) / now
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        project_dir=f"{output_dir}",
        mixed_precision=args.mixed_precision,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    patch_xformers(weight_dtype)

    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)

    if args.seed is None:
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
        args.pretrained_model_name_or_path
    )
    embeddings = patch_managed_embeddings(
        text_encoder, args.emb_alpha, args.emb_dropout
    )
    schedule_sampler = create_named_schedule_sampler(
        args.schedule_sampler, noise_scheduler.config.num_train_timesteps
    )

    vae.enable_slicing()

    if args.use_xformers:
        vae.set_use_memory_efficient_attention_xformers(True)
        unet.enable_xformers_memory_efficient_attention()
    # elif args.compile_unet:
    #     unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
    #
    #     proc = AttnProcessor()
    #
    #     def fn_recursive_set_proc(module: torch.nn.Module):
    #         if hasattr(module, "processor"):
    #             module.processor = proc
    #
    #         for child in module.children():
    #             fn_recursive_set_proc(child)
    #
    #     fn_recursive_set_proc(unet)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if len(args.alias_tokens) != 0:
        # --alias_tokens is a flat list of (alias, initializer) pairs.
        alias_placeholder_tokens = args.alias_tokens[::2]
        alias_initializer_tokens = args.alias_tokens[1::2]

        added_tokens, added_ids = add_placeholder_tokens(
            tokenizer=tokenizer,
            embeddings=embeddings,
            placeholder_tokens=alias_placeholder_tokens,
            initializer_tokens=alias_initializer_tokens,
        )
        embeddings.persist()
        print(
            f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
        )

    placeholder_tokens = []
    placeholder_token_ids = []

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(
            tokenizer, embeddings, embeddings_dir
        )
        print(
            f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
        )

        embeddings.persist()

    if len(args.placeholder_tokens) != 0:
        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
            tokenizer=tokenizer,
            embeddings=embeddings,
            placeholder_tokens=args.placeholder_tokens,
            initializer_tokens=args.initializer_tokens,
            num_vectors=args.num_vectors,
            initializer_noise=args.initializer_noise,
        )

        placeholder_tokens = args.placeholder_tokens

        stats = list(
            zip(
                placeholder_tokens,
                placeholder_token_ids,
                args.initializer_tokens,
                initializer_token_ids,
            )
        )
        print(f"Training embeddings: {stats}")

    if args.scale_lr:
        args.learning_rate_unet = (
            args.learning_rate_unet
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )
        args.learning_rate_text = (
            args.learning_rate_text
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    if args.find_lr:
        args.learning_rate_unet = 1e-6
        args.learning_rate_text = 1e-6
        args.lr_scheduler = "exponential_growth"

    # Optional optimizer packages are imported lazily so that only the selected
    # optimizer's dependency has to be installed.
    if args.optimizer == "adam8bit":
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err

        create_optimizer = partial(
            bnb.optim.AdamW8bit,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
    elif args.optimizer == "adam":
        create_optimizer = partial(
            torch.optim.AdamW,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
    elif args.optimizer == "adan":
        try:
            import timm.optim
        except ImportError as err:
            raise ImportError(
                "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
            ) from err

        create_optimizer = partial(
            timm.optim.Adan,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            no_prox=True,
        )
    elif args.optimizer == "lion":
        try:
            import lion_pytorch
        except ImportError as err:
            raise ImportError(
                "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
            ) from err

        create_optimizer = partial(
            lion_pytorch.Lion,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            use_triton=True,
        )
    elif args.optimizer == "adafactor":
        create_optimizer = partial(
            transformers.optimization.Adafactor,
            weight_decay=args.adam_weight_decay,
            scale_parameter=True,
            relative_step=True,
            warmup_init=True,
        )

        args.lr_scheduler = "adafactor"
        args.lr_min_lr = args.learning_rate_unet
        # The learning rates are cleared here; the cycle loop below has to
        # cope with None.
        args.learning_rate_unet = None
        args.learning_rate_text = None
    elif args.optimizer == "dadam":
        try:
            import dadaptation
        except ImportError as err:
            raise ImportError(
                "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
            ) from err

        create_optimizer = partial(
            dadaptation.DAdaptAdam,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=True,
            d0=args.dadaptation_d0,
            growth_rate=args.dadaptation_growth_rate,
        )

        args.learning_rate_unet = 1.0
        args.learning_rate_text = 1.0
    elif args.optimizer == "dadan":
        try:
            import dadaptation
        except ImportError as err:
            raise ImportError(
                "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
            ) from err

        create_optimizer = partial(
            dadaptation.DAdaptAdan,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            d0=args.dadaptation_d0,
            growth_rate=args.dadaptation_growth_rate,
        )

        args.learning_rate_unet = 1.0
        args.learning_rate_text = 1.0
    elif args.optimizer == "dlion":
        raise ImportError("DLion has not been merged into dadaptation yet")
    else:
        raise ValueError(f'Unknown --optimizer "{args.optimizer}"')

    trainer = partial(
        train,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        noise_scheduler=noise_scheduler,
        schedule_sampler=schedule_sampler,
        min_snr_gamma=args.min_snr_gamma,
        dtype=weight_dtype,
        seed=args.seed,
        compile_unet=args.compile_unet,
        # Prior preservation only applies when class images are used.
        prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
        sample_scheduler=sample_scheduler,
        sample_batch_size=args.sample_batch_size,
        sample_num_batches=args.sample_batches,
        sample_num_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
        max_grad_norm=args.max_grad_norm,
    )

    data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
    data_npgenerator = np.random.default_rng(args.seed)

    create_datamodule = partial(
        VlpnDataModule,
        data_file=args.train_data_file,
        tokenizer=tokenizer,
        constant_prompt_length=args.compile_unet,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        train_set_pad=args.train_set_pad,
        valid_set_pad=args.valid_set_pad,
        dtype=weight_dtype,
        generator=data_generator,
        npgenerator=data_npgenerator,
    )

    create_lr_scheduler = partial(
        get_scheduler,
        min_lr=args.lr_min_lr,
        warmup_func=args.lr_warmup_func,
        annealing_func=args.lr_annealing_func,
        warmup_exp=args.lr_warmup_exp,
        annealing_exp=args.lr_annealing_exp,
        end_lr=1e2,
        mid_point=args.lr_mid_point,
    )

    # Dreambooth
    # --------------------------------------------------------------------------------

    dreambooth_datamodule = create_datamodule(
        valid_set_size=args.valid_set_size,
        batch_size=args.train_batch_size,
        tag_dropout=args.tag_dropout,
        prompt_dropout=args.prompt_dropout,
        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
    )
    dreambooth_datamodule.setup()

    num_train_epochs = args.num_train_epochs
    dreambooth_sample_frequency = args.sample_frequency
    if num_train_epochs is None:
        # Derive the epoch count from the requested number of steps and rescale
        # the sample frequency by the same ratio.
        num_train_epochs = (
            math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset))
            * args.gradient_accumulation_steps
        )
        dreambooth_sample_frequency = math.ceil(
            num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps)
        )
    num_training_steps_per_epoch = math.ceil(
        len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps
    )
    num_train_steps = num_training_steps_per_epoch * num_train_epochs
    if args.sample_num is not None:
        dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num)

    dreambooth_project = "dreambooth"

    if accelerator.is_main_process:
        accelerator.init_trackers(dreambooth_project)

    dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples"

    training_iter = 0
    auto_cycles = list(args.auto_cycles)
    learning_rate_unet = args.learning_rate_unet
    learning_rate_text = args.learning_rate_text
    lr_scheduler = args.lr_scheduler
    lr_warmup_epochs = args.lr_warmup_epochs
    lr_cycles = args.lr_cycles

    avg_loss = AverageMeter()
    avg_acc = AverageMeter()
    avg_loss_val = AverageMeter()
    avg_acc_val = AverageMeter()

    params_to_optimize = [
        {
            "params": (param for param in unet.parameters() if param.requires_grad),
            "lr": learning_rate_unet,
        },
        {
            "params": (
                param for param in text_encoder.parameters() if param.requires_grad
            ),
            "lr": learning_rate_text,
        },
    ]
    group_labels = ["unet", "text"]

    dreambooth_optimizer = create_optimizer(params_to_optimize)

    while True:
        if len(auto_cycles) != 0:
            response = auto_cycles.pop(0)
        else:
            response = input(
                "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
            )

        action = response.lower().strip()

        # Stop and unknown actions are handled before any learning rate
        # arithmetic so they can never fail on it.
        if action == "s":
            break

        if action not in ("o", "w", "c", "d"):
            continue

        # Only the one_cycle action doubles the base learning rate.
        is_one_cycle = action == "o"
        learning_rate_unet = _cycle_learning_rate(
            args.learning_rate_unet, args.cycle_decay, training_iter, is_one_cycle
        )
        learning_rate_text = _cycle_learning_rate(
            args.learning_rate_text, args.cycle_decay, training_iter, is_one_cycle
        )

        # NOTE(review): every action overrides the scheduler name, so the
        # "adafactor" / "exponential_growth" values assigned to
        # args.lr_scheduler earlier never reach create_lr_scheduler - confirm
        # that this is intended.
        if is_one_cycle:
            lr_scheduler = "one_cycle"
            lr_warmup_epochs = args.lr_warmup_epochs
            lr_cycles = args.lr_cycles
        elif action == "w":
            lr_scheduler = "constant_with_warmup"
            lr_warmup_epochs = num_train_epochs
        elif action == "c":
            lr_scheduler = "constant"
        else:  # action == "d"
            lr_scheduler = "cosine"
            lr_warmup_epochs = 0
            lr_cycles = 1

        print("")
        print(
            f"============ Dreambooth cycle {training_iter + 1}: {response} ============"
        )
        print("")

        # The optimizer is reused across cycles; only the learning rates of its
        # parameter groups are reset.
        for group, lr in zip(
            dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text]
        ):
            group["lr"] = lr

        dreambooth_lr_scheduler = create_lr_scheduler(
            lr_scheduler,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            optimizer=dreambooth_optimizer,
            num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader),
            train_epochs=num_train_epochs,
            cycles=lr_cycles,
            warmup_epochs=lr_warmup_epochs,
        )

        dreambooth_checkpoint_output_dir = (
            output_dir / dreambooth_project / f"model_{training_iter}"
        )

        trainer(
            strategy=dreambooth_strategy,
            train_dataloader=dreambooth_datamodule.train_dataloader,
            val_dataloader=dreambooth_datamodule.val_dataloader,
            optimizer=dreambooth_optimizer,
            lr_scheduler=dreambooth_lr_scheduler,
            num_train_epochs=num_train_epochs,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            global_step_offset=training_iter * num_train_steps,
            cycle=training_iter,
            train_text_encoder_cycles=args.train_text_encoder_cycles,
            # --
            group_labels=group_labels,
            sample_output_dir=dreambooth_sample_output_dir,
            checkpoint_output_dir=dreambooth_checkpoint_output_dir,
            sample_frequency=dreambooth_sample_frequency,
            input_pertubation=args.input_pertubation,
            text_encoder_unfreeze_last_n_layers=args.text_encoder_unfreeze_last_n_layers,
            no_val=args.valid_set_size == 0,
            avg_loss=avg_loss,
            avg_acc=avg_acc,
            avg_loss_val=avg_loss_val,
            avg_acc_val=avg_acc_val,
        )

        training_iter += 1

    accelerator.end_training()


if __name__ == "__main__":
    main()