From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- train_dreambooth.py | 770 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 568 insertions(+), 202 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 2aca1e7..659b84c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -5,34 +5,70 @@ import itertools from pathlib import Path from functools import partial import math +import warnings import torch +import torch._dynamo import torch.utils.checkpoint +import hidet from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from slugify import slugify + +# from diffusers.models.attention_processor import AttnProcessor +from diffusers.utils.import_utils import is_xformers_available import transformers -from util.files import load_config, load_embeddings_from_dir +import numpy as np +from slugify import slugify + from data.csv import VlpnDataModule, keyword_filter -from training.functional import train, get_models +from models.clip.embeddings import patch_managed_embeddings +from training.functional import train, add_placeholder_tokens, get_models from training.strategy.dreambooth import dreambooth_strategy from training.optimization import get_scheduler -from training.util import save_args +from training.sampler import create_named_schedule_sampler +from training.util import AverageMeter, save_args +from util.files import load_config, load_embeddings_from_dir + logger = get_logger(__name__) +warnings.filterwarnings("ignore") + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True +# torch._dynamo.config.log_level = logging.WARNING +torch._dynamo.config.suppress_errors = True + +hidet.torch.dynamo_config.use_tensor_core(True) +hidet.torch.dynamo_config.search_space(0) + + +def patch_xformers(dtype): + if is_xformers_available(): + import xformers + import xformers.ops + + orig_xformers_memory_efficient_attention = ( + xformers.ops.memory_efficient_attention + ) + + def xformers_memory_efficient_attention( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs + ): + return orig_xformers_memory_efficient_attention( + query.to(dtype), key.to(dtype), value.to(dtype), **kwargs + ) + + xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention + def parse_args(): - parser = argparse.ArgumentParser( - description="Simple example of a training script." - ) + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -49,7 +85,7 @@ def parse_args(): "--train_data_file", type=str, default=None, - help="A folder containing the training data." + help="A folder containing the training data.", ) parser.add_argument( "--train_data_template", @@ -60,13 +96,13 @@ def parse_args(): "--train_set_pad", type=int, default=None, - help="The number to fill train dataset items up to." + help="The number to fill train dataset items up to.", ) parser.add_argument( "--valid_set_pad", type=int, default=None, - help="The number to fill validation dataset items up to." 
+ help="The number to fill validation dataset items up to.", ) parser.add_argument( "--project", @@ -75,20 +111,58 @@ def parse_args(): help="The name of the current project.", ) parser.add_argument( - "--exclude_collections", + "--auto_cycles", type=str, default="o", help="Cycles to run automatically." + ) + parser.add_argument( + "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." + ) + parser.add_argument( + "--placeholder_tokens", type=str, - nargs='*', - help="Exclude all items with a listed collection.", + nargs="*", + help="A token to use as a placeholder for the concept.", ) parser.add_argument( - "--train_text_encoder_epochs", - default=999999, - help="Number of epochs the text encoder will be trained." + "--initializer_tokens", + type=str, + nargs="*", + help="A token to use as initializer word.", + ) + parser.add_argument( + "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." + ) + parser.add_argument( + "--initializer_noise", + type=float, + default=0, + help="Noise to apply to the initializer word", + ) + parser.add_argument( + "--alias_tokens", + type=str, + nargs="*", + default=[], + help="Tokens to create an alias for.", + ) + parser.add_argument( + "--inverted_initializer_tokens", + type=str, + nargs="*", + help="A token to use as initializer word.", + ) + parser.add_argument( + "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." + ) + parser.add_argument( + "--exclude_collections", + type=str, + nargs="*", + help="Exclude all items with a listed collection.", ) parser.add_argument( "--num_buckets", type=int, - default=0, + default=2, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -119,19 +193,6 @@ def parse_args(): action="store_true", help="Shuffle tags.", ) - parser.add_argument( - "--vector_dropout", - type=int, - default=0, - help="Vector dropout probability.", - ) - parser.add_argument( - "--vector_shuffle", - type=str, - default="auto", - choices=["all", "trailing", "leading", "between", "auto", "off"], - help='Vector shuffling algorithm.', - ) parser.add_argument( "--guidance_scale", type=float, @@ -141,7 +202,7 @@ def parse_args(): "--num_class_images", type=int, default=0, - help="How many class images to generate." + help="How many class images to generate.", ) parser.add_argument( "--class_image_dir", @@ -161,17 +222,19 @@ def parse_args(): default=None, help="The embeddings directory where Textual Inversion embeddings are stored.", ) + parser.add_argument( + "--train_dir_embeddings", + action="store_true", + help="Train embeddings loaded from embeddings directory.", + ) parser.add_argument( "--collection", type=str, - nargs='*', + nargs="*", help="A collection to filter the dataset.", ) parser.add_argument( - "--seed", - type=int, - default=None, - help="A seed for reproducible training." + "--seed", type=int, default=None, help="A seed for reproducible training." ) parser.add_argument( "--resolution", @@ -189,15 +252,13 @@ def parse_args(): help="Perlin offset noise strength.", ) parser.add_argument( - "--num_train_epochs", - type=int, - default=None - ) - parser.add_argument( - "--num_train_steps", - type=int, - default=2000 + "--input_pertubation", + type=float, + default=0, + help="The scale of input pretubation. 
Recommended 0.1.", ) + parser.add_argument("--num_train_epochs", type=int, default=None) + parser.add_argument("--num_train_steps", type=int, default=2000) parser.add_argument( "--gradient_accumulation_steps", type=int, @@ -205,9 +266,9 @@ def parse_args(): help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + "--train_text_encoder_cycles", + default=999999, + help="Number of epochs the text encoder will be trained.", ) parser.add_argument( "--find_lr", @@ -215,9 +276,15 @@ def parse_args(): help="Automatically find a learning rate (no training).", ) parser.add_argument( - "--learning_rate", + "--learning_rate_unet", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--learning_rate_text", type=float, - default=2e-6, + default=5e-5, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -229,27 +296,31 @@ def parse_args(): "--lr_scheduler", type=str, default="one_cycle", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", - "constant", "constant_with_warmup", "one_cycle"], - help='The scheduler type to use.', + choices=[ + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + "constant", + "constant_with_warmup", + "one_cycle", + ], + help="The scheduler type to use.", ) parser.add_argument( "--lr_warmup_epochs", type=int, default=10, - help="Number of steps for the warmup in the lr scheduler." + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( - "--lr_mid_point", - type=float, - default=0.3, - help="OneCycle schedule mid point." + "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." ) parser.add_argument( "--lr_cycles", type=int, default=None, - help="Number of restart cycles in the lr scheduler (if supported)." + help="Number of restart cycles in the lr scheduler (if supported).", ) parser.add_argument( "--lr_warmup_func", @@ -261,7 +332,7 @@ def parse_args(): "--lr_warmup_exp", type=int, default=1, - help='If lr_warmup_func is "cos", exponent to modify the function' + help='If lr_warmup_func is "cos", exponent to modify the function', ) parser.add_argument( "--lr_annealing_func", @@ -273,76 +344,76 @@ def parse_args(): "--lr_annealing_exp", type=int, default=3, - help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' + help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', ) parser.add_argument( "--lr_min_lr", type=float, default=0.04, - help="Minimum learning rate in the lr scheduler." - ) - parser.add_argument( - "--use_ema", - action="store_true", - help="Whether to use EMA model." 
- ) - parser.add_argument( - "--ema_inv_gamma", - type=float, - default=1.0 + help="Minimum learning rate in the lr scheduler.", ) + parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") parser.add_argument( - "--ema_power", - type=float, - default=6/7 - ) - parser.add_argument( - "--ema_max_decay", - type=float, - default=0.9999 + "--schedule_sampler", + type=str, + default="uniform", + choices=["uniform", "loss-second-moment"], + help="Noise schedule sampler.", ) parser.add_argument( "--optimizer", type=str, - default="dadan", - choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], - help='Optimizer to use' + default="adan", + choices=[ + "adam", + "adam8bit", + "adan", + "lion", + "dadam", + "dadan", + "dlion", + "adafactor", + ], + help="Optimizer to use", ) parser.add_argument( "--dadaptation_d0", type=float, default=1e-6, - help="The d0 parameter for Dadaptation optimizers." + help="The d0 parameter for Dadaptation optimizers.", + ) + parser.add_argument( + "--dadaptation_growth_rate", + type=float, + default=math.inf, + help="The growth_rate parameter for Dadaptation optimizers.", ) parser.add_argument( "--adam_beta1", type=float, default=None, - help="The beta1 parameter for the Adam optimizer." + help="The beta1 parameter for the Adam optimizer.", ) parser.add_argument( "--adam_beta2", type=float, default=None, - help="The beta2 parameter for the Adam optimizer." + help="The beta2 parameter for the Adam optimizer.", ) parser.add_argument( - "--adam_weight_decay", - type=float, - default=1e-2, - help="Weight decay to use." + "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." ) parser.add_argument( "--adam_epsilon", type=float, default=1e-08, - help="Epsilon value for the Adam optimizer" + help="Epsilon value for the Adam optimizer", ) parser.add_argument( "--adam_amsgrad", type=bool, default=False, - help="Amsgrad value for the Adam optimizer" + help="Amsgrad value for the Adam optimizer", ) parser.add_argument( "--mixed_precision", @@ -355,12 +426,28 @@ def parse_args(): "and an Nvidia Ampere GPU." ), ) + parser.add_argument( + "--compile_unet", + action="store_true", + help="Compile UNet with Torch Dynamo.", + ) + parser.add_argument( + "--use_xformers", + action="store_true", + help="Use xformers.", + ) parser.add_argument( "--sample_frequency", type=int, default=1, help="How often to save a checkpoint and sample image", ) + parser.add_argument( + "--sample_num", + type=int, + default=None, + help="How often to save a checkpoint and sample image (in number of samples)", + ) parser.add_argument( "--sample_image_size", type=int, @@ -383,19 +470,19 @@ def parse_args(): "--valid_set_size", type=int, default=None, - help="Number of images in the validation dataset." + help="Number of images in the validation dataset.", ) parser.add_argument( "--valid_set_repeat", type=int, default=1, - help="Times the images in the validation dataset are repeated." + help="Times the images in the validation dataset are repeated.", ) parser.add_argument( "--train_batch_size", type=int, default=1, - help="Batch size (per device) for the training dataloader." + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( "--sample_steps", @@ -407,13 +494,18 @@ def parse_args(): "--prior_loss_weight", type=float, default=1.0, - help="The weight of prior preservation loss." 
+ help="The weight of prior preservation loss.", ) + parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") + parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") parser.add_argument( - "--max_grad_norm", - default=1.0, + "--emb_dropout", type=float, - help="Max gradient norm." + default=0, + help="Embedding dropout probability.", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." ) parser.add_argument( "--noise_timesteps", @@ -424,7 +516,7 @@ def parse_args(): "--config", type=str, default=None, - help="Path to a JSON configuration file containing arguments for invoking this script." + help="Path to a JSON configuration file containing arguments for invoking this script.", ) args = parser.parse_args() @@ -441,6 +533,67 @@ def parse_args(): if args.project is None: raise ValueError("You must specify --project") + if args.initializer_tokens is None: + args.initializer_tokens = [] + + if args.placeholder_tokens is None: + args.placeholder_tokens = [] + + if isinstance(args.placeholder_tokens, str): + args.placeholder_tokens = [args.placeholder_tokens] + + if isinstance(args.initializer_tokens, str): + args.initializer_tokens = [args.initializer_tokens] * len( + args.placeholder_tokens + ) + + if len(args.placeholder_tokens) == 0: + args.placeholder_tokens = [ + f"<*{i}>" for i in range(len(args.initializer_tokens)) + ] + + if len(args.initializer_tokens) == 0: + args.initializer_tokens = args.placeholder_tokens.copy() + + if len(args.placeholder_tokens) != len(args.initializer_tokens): + raise ValueError( + "--placeholder_tokens and --initializer_tokens must have the same number of items" + ) + + if isinstance(args.inverted_initializer_tokens, str): + args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( + args.placeholder_tokens + ) + + if ( + isinstance(args.inverted_initializer_tokens, list) + and len(args.inverted_initializer_tokens) != 0 + ): + args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] + args.initializer_tokens += args.inverted_initializer_tokens + + if isinstance(args.num_vectors, int): + args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) + + if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( + args.num_vectors + ): + raise ValueError( + "--placeholder_tokens and --num_vectors must have the same number of items" + ) + + if args.alias_tokens is None: + args.alias_tokens = [] + + if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: + raise ValueError("--alias_tokens must be a list with an even number of items") + + if args.filter_tokens is None: + args.filter_tokens = args.placeholder_tokens.copy() + + if isinstance(args.filter_tokens, str): + args.filter_tokens = [args.filter_tokens] + if isinstance(args.collection, str): args.collection = [args.collection] @@ -451,15 +604,15 @@ def parse_args(): raise ValueError("You must specify --output_dir") if args.adam_beta1 is None: - if args.optimizer in ('adam', 'adam8bit'): + if args.optimizer in ("adam", "adam8bit", "dadam"): args.adam_beta1 = 0.9 - elif args.optimizer == 'lion': + elif args.optimizer in ("lion", "dlion"): args.adam_beta1 = 0.95 if args.adam_beta2 is None: - if args.optimizer in ('adam', 'adam8bit'): + if args.optimizer in ("adam", "adam8bit", "dadam"): args.adam_beta2 = 0.999 - elif args.optimizer == 'lion': + elif args.optimizer in ("lion", "dlion"): args.adam_beta2 = 0.98 return args @@ -475,7 
+628,7 @@ def main(): accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, project_dir=f"{output_dir}", - mixed_precision=args.mixed_precision + mixed_precision=args.mixed_precision, ) weight_dtype = torch.float32 @@ -484,6 +637,8 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + patch_xformers(weight_dtype) + logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: @@ -493,44 +648,125 @@ def main(): save_args(output_dir, args) - tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( - args.pretrained_model_name_or_path) - - tokenizer.set_use_vector_shuffle(args.vector_shuffle) - tokenizer.set_dropout(args.vector_dropout) + tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( + args.pretrained_model_name_or_path + ) + embeddings = patch_managed_embeddings( + text_encoder, args.emb_alpha, args.emb_dropout + ) + schedule_sampler = create_named_schedule_sampler( + args.schedule_sampler, noise_scheduler.config.num_train_timesteps + ) vae.enable_slicing() - vae.set_use_memory_efficient_attention_xformers(True) - unet.enable_xformers_memory_efficient_attention() + + if args.use_xformers: + vae.set_use_memory_efficient_attention_xformers(True) + unet.enable_xformers_memory_efficient_attention() + # elif args.compile_unet: + # unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False + # + # proc = AttnProcessor() + # + # def fn_recursive_set_proc(module: torch.nn.Module): + # if hasattr(module, "processor"): + # module.processor = proc + # + # for child in module.children(): + # fn_recursive_set_proc(child) + # + # fn_recursive_set_proc(unet) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() + + if len(args.alias_tokens) != 0: + alias_placeholder_tokens = args.alias_tokens[::2] + alias_initializer_tokens = args.alias_tokens[1::2] + + added_tokens, added_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=alias_placeholder_tokens, + initializer_tokens=alias_initializer_tokens, + ) + embeddings.persist() + print( + f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" + ) + + placeholder_tokens = [] + placeholder_token_ids = [] if args.embeddings_dir is not None: embeddings_dir = Path(args.embeddings_dir) if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - embeddings.persist() - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + added_tokens, added_ids = load_embeddings_from_dir( + tokenizer, embeddings, embeddings_dir + ) + + placeholder_tokens = added_tokens + placeholder_token_ids = added_ids + + print( + f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" + ) + + if args.train_dir_embeddings: + print("Training embeddings from embeddings dir") + else: + embeddings.persist() + + if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: + placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=args.placeholder_tokens, + initializer_tokens=args.initializer_tokens, + 
num_vectors=args.num_vectors, + initializer_noise=args.initializer_noise, + ) + + placeholder_tokens = args.placeholder_tokens + + stats = list( + zip( + placeholder_tokens, + placeholder_token_ids, + args.initializer_tokens, + initializer_token_ids, + ) + ) + print(f"Training embeddings: {stats}") if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * - args.train_batch_size * accelerator.num_processes + args.learning_rate_unet = ( + args.learning_rate_unet + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes + ) + args.learning_rate_text = ( + args.learning_rate_text + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) if args.find_lr: - args.learning_rate = 1e-6 + args.learning_rate_unet = 1e-6 + args.learning_rate_text = 1e-6 args.lr_scheduler = "exponential_growth" - if args.optimizer == 'adam8bit': + if args.optimizer == "adam8bit": try: import bitsandbytes as bnb except ImportError: - raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) create_optimizer = partial( bnb.optim.AdamW8bit, @@ -539,7 +775,7 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - elif args.optimizer == 'adam': + elif args.optimizer == "adam": create_optimizer = partial( torch.optim.AdamW, betas=(args.adam_beta1, args.adam_beta2), @@ -547,22 +783,27 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - elif args.optimizer == 'adan': + elif args.optimizer == "adan": try: import timm.optim except ImportError: - raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") + raise ImportError( + "To use Adan, please install the PyTorch Image Models library: `pip install timm`." + ) create_optimizer = partial( timm.optim.Adan, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, + no_prox=True, ) - elif args.optimizer == 'lion': + elif args.optimizer == "lion": try: import lion_pytorch except ImportError: - raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") + raise ImportError( + "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." + ) create_optimizer = partial( lion_pytorch.Lion, @@ -570,7 +811,7 @@ def main(): weight_decay=args.adam_weight_decay, use_triton=True, ) - elif args.optimizer == 'adafactor': + elif args.optimizer == "adafactor": create_optimizer = partial( transformers.optimization.Adafactor, weight_decay=args.adam_weight_decay, @@ -580,13 +821,16 @@ def main(): ) args.lr_scheduler = "adafactor" - args.lr_min_lr = args.learning_rate - args.learning_rate = None - elif args.optimizer == 'dadam': + args.lr_min_lr = args.learning_rate_unet + args.learning_rate_unet = None + args.learning_rate_text = None + elif args.optimizer == "dadam": try: import dadaptation except ImportError: - raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") + raise ImportError( + "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." 
+ ) create_optimizer = partial( dadaptation.DAdaptAdam, @@ -595,46 +839,65 @@ def main(): eps=args.adam_epsilon, decouple=True, d0=args.dadaptation_d0, + growth_rate=args.dadaptation_growth_rate, ) - args.learning_rate = 1.0 - elif args.optimizer == 'dadan': + args.learning_rate_unet = 1.0 + args.learning_rate_text = 1.0 + elif args.optimizer == "dadan": try: import dadaptation except ImportError: - raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") + raise ImportError( + "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." + ) create_optimizer = partial( dadaptation.DAdaptAdan, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, d0=args.dadaptation_d0, + growth_rate=args.dadaptation_growth_rate, ) - args.learning_rate = 1.0 + args.learning_rate_unet = 1.0 + args.learning_rate_text = 1.0 + elif args.optimizer == "dlion": + raise ImportError("DLion has not been merged into dadaptation yet") else: - raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") + raise ValueError(f'Unknown --optimizer "{args.optimizer}"') trainer = partial( train, accelerator=accelerator, unet=unet, text_encoder=text_encoder, + tokenizer=tokenizer, vae=vae, noise_scheduler=noise_scheduler, + schedule_sampler=schedule_sampler, + min_snr_gamma=args.min_snr_gamma, dtype=weight_dtype, + seed=args.seed, + compile_unet=args.compile_unet, guidance_scale=args.guidance_scale, prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, - no_val=args.valid_set_size == 0, + sample_scheduler=sample_scheduler, + sample_batch_size=args.sample_batch_size, + sample_num_batches=args.sample_batches, + sample_num_steps=args.sample_steps, + sample_image_size=args.sample_image_size, + max_grad_norm=args.max_grad_norm, ) - checkpoint_output_dir = output_dir / "model" - sample_output_dir = output_dir / "samples" + data_generator = torch.Generator(device="cpu").manual_seed(args.seed) + data_npgenerator = np.random.default_rng(args.seed) - datamodule = VlpnDataModule( + create_datamodule = partial( + VlpnDataModule, data_file=args.train_data_file, - batch_size=args.train_batch_size, tokenizer=tokenizer, + constant_prompt_length=args.compile_unet, class_subdir=args.class_image_dir, with_guidance=args.guidance_scale != 0, num_class_images=args.num_class_images, @@ -643,83 +906,186 @@ def main(): progressive_buckets=args.progressive_buckets, bucket_step_size=args.bucket_step_size, bucket_max_pixels=args.bucket_max_pixels, - dropout=args.tag_dropout, shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, - valid_set_size=args.valid_set_size, train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, - seed=args.seed, - filter=partial(keyword_filter, None, args.collection, args.exclude_collections), - dtype=weight_dtype - ) - datamodule.setup() - - num_train_epochs = args.num_train_epochs - sample_frequency = args.sample_frequency - if num_train_epochs is None: - num_train_epochs = math.ceil( - args.num_train_steps / len(datamodule.train_dataset) - ) * args.gradient_accumulation_steps - sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) - - params_to_optimize = (unet.parameters(), ) - if args.train_text_encoder_epochs != 0: - params_to_optimize += ( - text_encoder.text_model.encoder.parameters(), - text_encoder.text_model.final_layer_norm.parameters(), - ) - - optimizer = create_optimizer( - itertools.chain(*params_to_optimize), - lr=args.learning_rate, + 
dtype=weight_dtype, + generator=data_generator, + npgenerator=data_npgenerator, ) - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_training_steps_per_epoch=len(datamodule.train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps, + create_lr_scheduler = partial( + get_scheduler, min_lr=args.lr_min_lr, warmup_func=args.lr_warmup_func, annealing_func=args.lr_annealing_func, warmup_exp=args.lr_warmup_exp, annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, end_lr=1e2, - train_epochs=num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, mid_point=args.lr_mid_point, ) - trainer( - strategy=dreambooth_strategy, - project="dreambooth", - train_dataloader=datamodule.train_dataloader, - val_dataloader=datamodule.val_dataloader, - seed=args.seed, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - num_train_epochs=num_train_epochs, - gradient_accumulation_steps=args.gradient_accumulation_steps, - sample_frequency=sample_frequency, - offset_noise_strength=args.offset_noise_strength, - # -- - tokenizer=tokenizer, - sample_scheduler=sample_scheduler, - sample_output_dir=sample_output_dir, - checkpoint_output_dir=checkpoint_output_dir, - train_text_encoder_epochs=args.train_text_encoder_epochs, - max_grad_norm=args.max_grad_norm, - use_ema=args.use_ema, - ema_inv_gamma=args.ema_inv_gamma, - ema_power=args.ema_power, - ema_max_decay=args.ema_max_decay, - sample_batch_size=args.sample_batch_size, - sample_num_batches=args.sample_batches, - sample_num_steps=args.sample_steps, - sample_image_size=args.sample_image_size, + # Dreambooth + # -------------------------------------------------------------------------------- + + dreambooth_datamodule = create_datamodule( + valid_set_size=args.valid_set_size, + batch_size=args.train_batch_size, + dropout=args.tag_dropout, + filter=partial(keyword_filter, None, args.collection, args.exclude_collections), + ) + dreambooth_datamodule.setup() + + num_train_epochs = args.num_train_epochs + dreambooth_sample_frequency = args.sample_frequency + if num_train_epochs is None: + num_train_epochs = ( + math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset)) + * args.gradient_accumulation_steps + ) + dreambooth_sample_frequency = math.ceil( + num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps) + ) + num_training_steps_per_epoch = math.ceil( + len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps ) + num_train_steps = num_training_steps_per_epoch * num_train_epochs + if args.sample_num is not None: + dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num) + + dreambooth_project = "dreambooth" + + if accelerator.is_main_process: + accelerator.init_trackers(dreambooth_project) + + dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples" + + training_iter = 0 + auto_cycles = list(args.auto_cycles) + learning_rate_unet = args.learning_rate_unet + learning_rate_text = args.learning_rate_text + lr_scheduler = args.lr_scheduler + lr_warmup_epochs = args.lr_warmup_epochs + lr_cycles = args.lr_cycles + + avg_loss = AverageMeter() + avg_acc = AverageMeter() + avg_loss_val = AverageMeter() + avg_acc_val = AverageMeter() + + params_to_optimize = [ + { + "params": (param for param in unet.parameters() if param.requires_grad), + "lr": learning_rate_unet, + }, + { + "params": ( + param for param in text_encoder.parameters() if param.requires_grad + ), + "lr": learning_rate_text, + }, + ] + group_labels = ["unet", "text"] + + 
dreambooth_optimizer = create_optimizer(params_to_optimize) + + while True: + if len(auto_cycles) != 0: + response = auto_cycles.pop(0) + else: + response = input( + "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " + ) + + if response.lower().strip() == "o": + if args.learning_rate_unet is not None: + learning_rate_unet = ( + args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) + ) + if args.learning_rate_text is not None: + learning_rate_text = ( + args.learning_rate_text * 2 * (args.cycle_decay**training_iter) + ) + else: + learning_rate_unet = args.learning_rate_unet * ( + args.cycle_decay**training_iter + ) + learning_rate_text = args.learning_rate_text * ( + args.cycle_decay**training_iter + ) + + if response.lower().strip() == "o": + lr_scheduler = "one_cycle" + lr_warmup_epochs = args.lr_warmup_epochs + lr_cycles = args.lr_cycles + elif response.lower().strip() == "w": + lr_scheduler = "constant_with_warmup" + lr_warmup_epochs = num_train_epochs + elif response.lower().strip() == "c": + lr_scheduler = "constant" + elif response.lower().strip() == "d": + lr_scheduler = "cosine" + lr_warmup_epochs = 0 + lr_cycles = 1 + elif response.lower().strip() == "s": + break + else: + continue + + print("") + print( + f"============ Dreambooth cycle {training_iter + 1}: {response} ============" + ) + print("") + + for group, lr in zip( + dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text] + ): + group["lr"] = lr + + dreambooth_lr_scheduler = create_lr_scheduler( + lr_scheduler, + gradient_accumulation_steps=args.gradient_accumulation_steps, + optimizer=dreambooth_optimizer, + num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader), + train_epochs=num_train_epochs, + cycles=lr_cycles, + warmup_epochs=lr_warmup_epochs, + ) + + dreambooth_checkpoint_output_dir = ( + output_dir / dreambooth_project / f"model_{training_iter}" + ) + + trainer( + strategy=dreambooth_strategy, + train_dataloader=dreambooth_datamodule.train_dataloader, + val_dataloader=dreambooth_datamodule.val_dataloader, + optimizer=dreambooth_optimizer, + lr_scheduler=dreambooth_lr_scheduler, + num_train_epochs=num_train_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + global_step_offset=training_iter * num_train_steps, + cycle=training_iter, + train_text_encoder_cycles=args.train_text_encoder_cycles, + # -- + group_labels=group_labels, + sample_output_dir=dreambooth_sample_output_dir, + checkpoint_output_dir=dreambooth_checkpoint_output_dir, + sample_frequency=dreambooth_sample_frequency, + offset_noise_strength=args.offset_noise_strength, + input_pertubation=args.input_pertubation, + no_val=args.valid_set_size == 0, + avg_loss=avg_loss, + avg_acc=avg_acc, + avg_loss_val=avg_loss_val, + avg_acc_val=avg_acc_val, + ) + + training_iter += 1 + + accelerator.end_training() if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf
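
For reference, the `--auto_cycles` / `--cycle_decay` flow introduced above consumes one character per training cycle ("o" one_cycle, "w" warmup, "c" constant, "d" decay, "s" stop) and scales both base learning rates by `cycle_decay**cycle`, with one_cycle runs additionally doubled. Below is a minimal standalone sketch of that selection logic, not part of the patch; helper names such as `schedule_for` and `cycle_plan` are illustrative only, and the None-handling for the adafactor case is omitted.

# Illustrative sketch of the per-cycle scheduler / learning-rate selection above.
# `schedule_for` and `cycle_plan` are hypothetical names, not identifiers from the patch.

def schedule_for(action: str) -> str:
    # Mirrors the interactive prompt: o/w/c/d select a scheduler type.
    return {
        "o": "one_cycle",
        "w": "constant_with_warmup",
        "c": "constant",
        "d": "cosine",
    }[action]


def cycle_plan(base_lr_unet, base_lr_text, cycle_decay, auto_cycles):
    for cycle, action in enumerate(auto_cycles):
        if action == "s":
            break
        decay = cycle_decay**cycle
        # one_cycle runs double the base learning rate; the other modes use it as-is.
        factor = 2.0 if action == "o" else 1.0
        yield (
            schedule_for(action),
            base_lr_unet * factor * decay,
            base_lr_text * factor * decay,
        )


if __name__ == "__main__":
    # e.g. --auto_cycles "owd" --cycle_decay 0.9 with the default learning rates
    for sched, lr_unet, lr_text in cycle_plan(1e-4, 5e-5, 0.9, "owd"):
        print(f"{sched}: unet={lr_unet:.2e}, text={lr_text:.2e}")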