From 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 18:59:26 +0100 Subject: More modularization --- train_ti.py | 479 +++++++++--------------------------------------------------- 1 file changed, 70 insertions(+), 409 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 3f4e739..3a55f40 100644 --- a/train_ti.py +++ b/train_ti.py @@ -1,31 +1,15 @@ import argparse -import math -import datetime -import logging -from functools import partial -from pathlib import Path -from contextlib import contextmanager, nullcontext import torch import torch.utils.checkpoint -from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel -import matplotlib.pyplot as plt -from tqdm.auto import tqdm -from transformers import CLIPTextModel -from slugify import slugify - -from util import load_config, load_embeddings_from_dir -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from data.csv import VlpnDataModule, VlpnDataItem -from training.common import loss_step, train_loop, generate_class_images, get_scheduler -from training.lr import LRFinder -from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args -from models.clip.embeddings import patch_managed_embeddings -from models.clip.tokenizer import MultiCLIPTokenizer + +from util import load_config +from data.csv import VlpnDataItem +from training.common import train_setup +from training.modules.ti import train_ti +from training.util import save_args logger = get_logger(__name__) @@ -271,7 +255,7 @@ def parse_args(): parser.add_argument( "--lr_min_lr", type=float, - default=None, + default=0.04, help="Minimum learning rate in the lr scheduler." ) parser.add_argument( @@ -401,19 +385,19 @@ def parse_args(): help="The weight of prior preservation loss." ) parser.add_argument( - "--decay_target", - default=None, + "--emb_decay_target", + default=0.4, type=float, help="Embedding decay target." ) parser.add_argument( - "--decay_factor", + "--emb_decay_factor", default=1, type=float, help="Embedding decay factor." ) parser.add_argument( - "--decay_start", + "--emb_decay_start", default=1e-4, type=float, help="Embedding decay start offset." @@ -491,213 +475,10 @@ def parse_args(): return args -class Checkpointer(CheckpointerBase): - def __init__( - self, - weight_dtype, - accelerator: Accelerator, - vae: AutoencoderKL, - unet: UNet2DConditionModel, - tokenizer: MultiCLIPTokenizer, - text_encoder: CLIPTextModel, - ema_embeddings: EMAModel, - scheduler, - placeholder_token, - new_ids, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - - self.weight_dtype = weight_dtype - self.accelerator = accelerator - self.vae = vae - self.unet = unet - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.ema_embeddings = ema_embeddings - self.scheduler = scheduler - self.placeholder_token = placeholder_token - self.new_ids = new_ids - - @torch.no_grad() - def checkpoint(self, step, postfix): - print("Saving checkpoint for step %d..." 
% step) - - checkpoints_path = self.output_dir.joinpath("checkpoints") - checkpoints_path.mkdir(parents=True, exist_ok=True) - - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = self.ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() - - with ema_context: - for (token, ids) in zip(self.placeholder_token, self.new_ids): - text_encoder.text_model.embeddings.save_embed( - ids, - checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") - ) - - del text_encoder - - @torch.no_grad() - def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = self.ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() - - with ema_context: - orig_dtype = text_encoder.dtype - text_encoder.to(dtype=self.weight_dtype) - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=self.unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ).to(self.accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) - - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) - - text_encoder.to(dtype=orig_dtype) - - del text_encoder - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - def main(): args = parse_args() - global_step_offset = args.global_step - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(args.project), now) - basepath.mkdir(parents=True, exist_ok=True) - - accelerator = Accelerator( - log_with=LoggerType.TENSORBOARD, - logging_dir=f"{basepath}", - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision - ) - - logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) - - args.seed = args.seed or (torch.random.seed() >> 32) - set_seed(args.seed) - - save_args(basepath, args) - - # Load the tokenizer and add the placeholder token as a additional special token - if args.tokenizer_name: - tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) - elif args.pretrained_model_name_or_path: - tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') - tokenizer.set_use_vector_shuffle(args.vector_shuffle) - tokenizer.set_dropout(args.vector_dropout) - - # Load models and create wrapper for stable diffusion - text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') - checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder='scheduler') - - vae.enable_slicing() - vae.set_use_memory_efficient_attention_xformers(True) - unet.set_use_memory_efficient_attention_xformers(True) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - embeddings = patch_managed_embeddings(text_encoder) - ema_embeddings = None - - if 
args.embeddings_dir is not None: - embeddings_dir = Path(args.embeddings_dir) - if not embeddings_dir.exists() or not embeddings_dir.is_dir(): - raise ValueError("--embeddings_dir must point to an existing directory") - - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") - - # Convert the initializer_token, placeholder_token to ids - initializer_token_ids = [ - tokenizer.encode(token, add_special_tokens=False) - for token in args.initializer_token - ] - - new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) - embeddings.resize(len(tokenizer)) - - for (new_id, init_ids) in zip(new_ids, initializer_token_ids): - embeddings.add_embed(new_id, init_ids) - - init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)] - - print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") - - if args.use_ema: - ema_embeddings = EMAModel( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - ) - - vae.requires_grad_(False) - unet.requires_grad_(False) - - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * - args.train_batch_size * accelerator.num_processes - ) - - if args.find_lr: - args.learning_rate = 1e-5 - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - # Initialize the optimizer - optimizer = optimizer_class( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), # only optimize the embeddings - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - amsgrad=args.adam_amsgrad, - ) - - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - def keyword_filter(item: VlpnDataItem): + def data_filter(item: VlpnDataItem): cond1 = any( keyword in part for keyword in args.placeholder_token @@ -710,198 +491,78 @@ def main(): ) return cond1 and cond3 and cond4 - datamodule = VlpnDataModule( + setup = train_setup( + output_dir=args.output_dir, + project=args.project, + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + learning_rate=args.learning_rate, data_file=args.train_data_file, - batch_size=args.train_batch_size, - tokenizer=tokenizer, - class_subdir=args.class_image_dir, + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + seed=args.seed, + vector_shuffle=args.vector_shuffle, + vector_dropout=args.vector_dropout, + gradient_checkpointing=args.gradient_checkpointing, + embeddings_dir=args.embeddings_dir, + 
placeholder_token=args.placeholder_token, + initializer_token=args.initializer_token, + num_vectors=args.num_vectors, + scale_lr=args.scale_lr, + use_8bit_adam=args.use_8bit_adam, + train_batch_size=args.train_batch_size, + class_image_dir=args.class_image_dir, num_class_images=args.num_class_images, - size=args.resolution, + resolution=args.resolution, num_buckets=args.num_buckets, progressive_buckets=args.progressive_buckets, bucket_step_size=args.bucket_step_size, bucket_max_pixels=args.bucket_max_pixels, - dropout=args.tag_dropout, - shuffle=not args.no_tag_shuffle, - template_key=args.train_data_template, + tag_dropout=args.tag_dropout, + tag_shuffle=not args.no_tag_shuffle, + data_template=args.train_data_template, valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, - num_workers=args.dataloader_num_workers, - seed=args.seed, - filter=keyword_filter, - dtype=weight_dtype - ) - datamodule.setup() - - train_dataloader = datamodule.train_dataloader - val_dataloader = datamodule.val_dataloader - - if args.num_class_images != 0: - generate_class_images( - accelerator, - text_encoder, - vae, - unet, - tokenizer, - checkpoint_scheduler, - datamodule.data_train, - args.sample_batch_size, - args.sample_image_size, - args.sample_steps - ) - - if args.find_lr: - lr_scheduler = None - else: - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - min_lr=args.lr_min_lr, - lr=args.learning_rate, - warmup_func=args.lr_warmup_func, - annealing_func=args.lr_annealing_func, - warmup_exp=args.lr_warmup_exp, - annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, - train_epochs=args.num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, - num_training_steps_per_epoch=len(train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps - ) - - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + data_filter=data_filter, + sample_image_size=args.sample_image_size, + sample_batch_size=args.sample_batch_size, + sample_steps=args.sample_steps, ) - # Move vae and unet to device - vae.to(accelerator.device, dtype=weight_dtype) - unet.to(accelerator.device, dtype=weight_dtype) - - if args.use_ema: - ema_embeddings.to(accelerator.device) + save_args(setup.output_dir, args) - # Keep vae and unet in eval mode as we don't train these - vae.eval() - - if args.gradient_checkpointing: - unet.train() - else: - unet.eval() - - @contextmanager - def on_train(): - try: - tokenizer.train() - yield - finally: - pass - - @contextmanager - def on_eval(): - try: - tokenizer.eval() - - ema_context = ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext() - - with ema_context: - yield - finally: - pass - - @torch.no_grad() - def on_after_optimize(lr: float): - text_encoder.text_model.embeddings.normalize( - args.decay_target, - min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) - ) - - if args.use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - def on_log(): - if args.use_ema: - return {"ema_decay": ema_embeddings.decay} - return {} - - loss_step_ = partial( - loss_step, - vae, - noise_scheduler, - unet, - text_encoder, - args.num_class_images != 0, - args.prior_loss_weight, - args.seed, - ) - - checkpointer = Checkpointer( - weight_dtype=weight_dtype, 
- datamodule=datamodule, - accelerator=accelerator, - vae=vae, - unet=unet, - tokenizer=tokenizer, - text_encoder=text_encoder, - ema_embeddings=ema_embeddings, - scheduler=checkpoint_scheduler, - placeholder_token=args.placeholder_token, - new_ids=new_ids, - output_dir=basepath, + train_ti( + setup=setup, + num_train_epochs=args.num_train_epochs, + num_class_images=args.num_class_images, + prior_loss_weight=args.prior_loss_weight, + use_ema=args.use_ema, + ema_inv_gamma=args.ema_inv_gamma, + ema_power=args.ema_power, + ema_max_decay=args.ema_max_decay, + adam_beta1=args.adam_beta1, + adam_beta2=args.adam_beta2, + adam_weight_decay=args.adam_weight_decay, + adam_epsilon=args.adam_epsilon, + adam_amsgrad=args.adam_amsgrad, + lr_scheduler=args.lr_scheduler, + lr_min_lr=args.lr_min_lr, + lr_warmup_func=args.lr_warmup_func, + lr_annealing_func=args.lr_annealing_func, + lr_warmup_exp=args.lr_warmup_exp, + lr_annealing_exp=args.lr_annealing_exp, + lr_cycles=args.lr_cycles, + lr_warmup_epochs=args.lr_warmup_epochs, + emb_decay_target=args.emb_decay_target, + emb_decay_factor=args.emb_decay_factor, + emb_decay_start=args.emb_decay_start, sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, sample_batches=args.sample_batches, - seed=args.seed - ) - - if accelerator.is_main_process: - config = vars(args).copy() - config["initializer_token"] = " ".join(config["initializer_token"]) - config["placeholder_token"] = " ".join(config["placeholder_token"]) - config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]]) - if config["collection"] is not None: - config["collection"] = " ".join(config["collection"]) - if config["exclude_collections"] is not None: - config["exclude_collections"] = " ".join(config["exclude_collections"]) - accelerator.init_trackers("textual_inversion", config=config) - - if args.find_lr: - lr_finder = LRFinder( - accelerator=accelerator, - optimizer=optimizer, - model=text_encoder, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - loss_step=loss_step_, - on_train=on_train, - on_eval=on_eval, - on_after_optimize=on_after_optimize, - ) - lr_finder.run(num_epochs=100, end_lr=1e3) - - plt.savefig(basepath.joinpath("lr.png"), dpi=300) - plt.close() - else: - train_loop( - accelerator=accelerator, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model=text_encoder, - checkpointer=checkpointer, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - loss_step=loss_step_, - sample_frequency=args.sample_frequency, - sample_steps=args.sample_steps, - checkpoint_frequency=args.checkpoint_frequency, - global_step_offset=global_step_offset, - gradient_accumulation_steps=args.gradient_accumulation_steps, - num_epochs=args.num_train_epochs, - on_log=on_log, - on_train=on_train, - on_after_optimize=on_after_optimize, - on_eval=on_eval - ) + sample_frequency=args.sample_frequency, + sample_steps=args.sample_steps, + checkpoint_frequency=args.checkpoint_frequency, + global_step_offset=args.global_step, + ) if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf
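
Note on the renamed decay flags: the per-step strength of the embedding decay that the new --emb_decay_* arguments control was, before this commit, computed inline in the on_after_optimize() hook removed above, scaled by the current learning rate and clamped to [0, 1]. The sketch below is an illustrative reconstruction of that computation under the assumption that the logic moved unchanged into training/modules/ti.py; the helper name emb_decay_weight() and its keyword defaults are taken from the removed lines and the new argument defaults, not from the new module itself.

    # Sketch (assumption: the decay logic moved as-is into training/modules/ti.py).
    # The weight grows with the current LR and is clamped to [0, 1]; the trained
    # token embeddings are then normalized toward --emb_decay_target (default 0.4).

    def emb_decay_weight(
        lr: float,                      # LR reported after the optimizer step
        learning_rate: float,           # --learning_rate (the peak LR)
        emb_decay_factor: float = 1.0,  # --emb_decay_factor
        emb_decay_start: float = 1e-4,  # --emb_decay_start
    ) -> float:
        return min(
            1.0,
            max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))),
        )

    # Used roughly as (hypothetical call site, mirroring the removed hook):
    #   embeddings.normalize(
    #       args.emb_decay_target,
    #       emb_decay_weight(lr, args.learning_rate, args.emb_decay_factor, args.emb_decay_start),
    #   )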