From 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 18:59:26 +0100 Subject: More modularization --- models/clip/embeddings.py | 6 +- train_dreambooth.py | 272 ++++++----------------- train_ti.py | 479 ++++++----------------------------------- training/common.py | 260 ++++++++++++++++++++-- training/lr.py | 14 +- training/modules/dreambooth.py | 0 training/modules/lora.py | 0 training/modules/ti.py | 284 ++++++++++++++++++++++++ training/util.py | 15 +- 9 files changed, 677 insertions(+), 653 deletions(-) create mode 100644 training/modules/dreambooth.py create mode 100644 training/modules/lora.py create mode 100644 training/modules/ti.py diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 761efbc..9a23a2a 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -40,8 +40,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.position_embedding = embeddings.position_embedding self.initializer_factor = config.initializer_factor - self.decay_target = self.token_embedding.weight[:, :].norm(dim=-1, keepdim=True).median().item() - self.temp_token_embedding = nn.Embedding( self.token_embedding.num_embeddings, self.token_embedding.embedding_dim, @@ -101,9 +99,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): return embeds - def normalize(self, target: Optional[float] = None, lambda_: float = 1.0): - if target is None: - target = self.decay_target + def normalize(self, target: float = 0.4, lambda_: float = 1.0): w = self.temp_token_embedding.weight pre_norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) w[self.temp_token_ids] = F.normalize( diff --git a/train_dreambooth.py b/train_dreambooth.py index fbbe6c2..c892ebf 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -1,6 +1,5 @@ import argparse import itertools -import math import datetime import logging from pathlib import Path @@ -16,16 +15,15 @@ from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel import matplotlib.pyplot as plt from diffusers.training_utils import EMAModel -from tqdm.auto import tqdm from transformers import CLIPTextModel from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import VlpnDataModule, VlpnDataItem -from training.common import loss_step, generate_class_images, get_scheduler +from training.common import loss_step, train_loop, generate_class_images, get_scheduler from training.lr import LRFinder -from training.util import AverageMeter, CheckpointerBase, save_args +from training.util import CheckpointerBase, save_args from models.clip.embeddings import patch_managed_embeddings from models.clip.tokenizer import MultiCLIPTokenizer @@ -292,7 +290,7 @@ def parse_args(): parser.add_argument( "--lr_min_lr", type=float, - default=None, + default=0.04, help="Minimum learning rate in the lr scheduler." ) parser.add_argument( @@ -787,14 +785,6 @@ def main(): args.sample_steps ) - # Scheduler and math around the number of training steps. - overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True - num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - if args.find_lr: lr_scheduler = None else: @@ -802,15 +792,14 @@ def main(): args.lr_scheduler, optimizer=optimizer, min_lr=args.lr_min_lr, - lr=args.learning_rate, warmup_func=args.lr_warmup_func, annealing_func=args.lr_annealing_func, warmup_exp=args.lr_warmup_exp, annealing_exp=args.lr_annealing_exp, cycles=args.lr_cycles, + train_epochs=args.num_train_epochs, warmup_epochs=args.lr_warmup_epochs, - max_train_steps=args.max_train_steps, - num_update_steps_per_epoch=num_update_steps_per_epoch, + num_training_steps_per_epoch=len(train_dataloader), gradient_accumulation_steps=args.gradient_accumulation_steps ) @@ -827,19 +816,16 @@ def main(): if args.use_ema: ema_unet.to(accelerator.device) - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - - num_val_steps_per_epoch = len(val_dataloader) - num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - val_steps = num_val_steps_per_epoch * num_epochs - @contextmanager - def on_train(): + def on_train(epoch: int): try: tokenizer.train() + + if epoch < args.train_text_encoder_epochs: + text_encoder.train() + elif epoch == args.train_text_encoder_epochs: + text_encoder.requires_grad_(False) + yield finally: pass @@ -848,6 +834,7 @@ def main(): def on_eval(): try: tokenizer.eval() + text_encoder.eval() ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext() @@ -856,7 +843,7 @@ def main(): finally: pass - def on_before_optimize(): + def on_before_optimize(epoch: int): if accelerator.sync_gradients: params_to_clip = [unet.parameters()] if args.train_text_encoder and epoch < args.train_text_encoder_epochs: @@ -866,9 +853,17 @@ def main(): @torch.no_grad() def on_after_optimize(lr: float): if not args.train_text_encoder: - text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) + text_encoder.text_model.embeddings.normalize( + args.decay_target, + min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) + ) + + def on_log(): + if args.use_ema: + return {"ema_decay": ema_unet.decay} + return {} - loop = partial( + loss_step_ = partial( loss_step, vae, noise_scheduler, @@ -879,8 +874,25 @@ def main(): args.seed, ) - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. + checkpointer = Checkpointer( + weight_dtype=weight_dtype, + datamodule=datamodule, + accelerator=accelerator, + vae=vae, + unet=unet, + ema_unet=ema_unet, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=checkpoint_scheduler, + output_dir=basepath, + placeholder_token=args.placeholder_token, + placeholder_token_id=placeholder_token_id, + sample_image_size=args.sample_image_size, + sample_batch_size=args.sample_batch_size, + sample_batches=args.sample_batches, + seed=args.seed + ) + if accelerator.is_main_process: config = vars(args).copy() config["initializer_token"] = " ".join(config["initializer_token"]) @@ -898,9 +910,9 @@ def main(): optimizer, train_dataloader, val_dataloader, - loop, - on_train=tokenizer.train, - on_eval=tokenizer.eval, + loss_step_, + on_train=on_train, + on_eval=on_eval, on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, ) @@ -909,182 +921,28 @@ def main(): plt.savefig(basepath.joinpath("lr.png"), dpi=300) plt.close() - quit() - - # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num Epochs = {num_epochs}") - logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - - global_step = 0 - - avg_loss = AverageMeter() - avg_acc = AverageMeter() + return - avg_loss_val = AverageMeter() - avg_acc_val = AverageMeter() - - max_acc_val = 0.0 - - checkpointer = Checkpointer( - weight_dtype=weight_dtype, - datamodule=datamodule, + train_loop( accelerator=accelerator, - vae=vae, - unet=unet, - ema_unet=ema_unet, - tokenizer=tokenizer, - text_encoder=text_encoder, - scheduler=checkpoint_scheduler, - output_dir=basepath, - placeholder_token=args.placeholder_token, - placeholder_token_id=placeholder_token_id, - sample_image_size=args.sample_image_size, - sample_batch_size=args.sample_batch_size, - sample_batches=args.sample_batches, - seed=args.seed - ) - - local_progress_bar = tqdm( - range(num_update_steps_per_epoch + num_val_steps_per_epoch), - disable=not accelerator.is_local_main_process, - dynamic_ncols=True - ) - local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") - - global_progress_bar = tqdm( - range(args.max_train_steps + val_steps), - disable=not accelerator.is_local_main_process, - dynamic_ncols=True + optimizer=optimizer, + lr_scheduler=lr_scheduler, + model=unet, + checkpointer=checkpointer, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + loss_step=loss_step_, + sample_frequency=args.sample_frequency, + sample_steps=args.sample_steps, + checkpoint_frequency=args.checkpoint_frequency, + global_step_offset=0, + gradient_accumulation_steps=args.gradient_accumulation_steps, + num_epochs=args.num_train_epochs, + on_log=on_log, + on_train=on_train, + on_after_optimize=on_after_optimize, + on_eval=on_eval ) - global_progress_bar.set_description("Total progress") - - try: - for epoch in range(num_epochs): - if accelerator.is_main_process: - if epoch % args.sample_frequency == 0: - checkpointer.save_samples(global_step, args.sample_steps) - - local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") - local_progress_bar.reset() - - unet.train() - if epoch < args.train_text_encoder_epochs: - text_encoder.train() - elif epoch == args.train_text_encoder_epochs: - text_encoder.requires_grad_(False) - - with on_train(): - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(unet): - loss, acc, bsz = loop(step, batch) - - accelerator.backward(loss) - - on_before_optimize() - - optimizer.step() - if not accelerator.optimizer_step_was_skipped: - lr_scheduler.step() - if args.use_ema: - ema_unet.step(unet.parameters()) - optimizer.zero_grad(set_to_none=True) - - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - on_after_optimize(lr_scheduler.get_last_lr()[0]) - - local_progress_bar.update(1) - global_progress_bar.update(1) - - global_step += 1 - - logs = { - "train/loss": avg_loss.avg.item(), - "train/acc": avg_acc.avg.item(), - "train/cur_loss": loss.item(), - "train/cur_acc": acc.item(), - "lr": lr_scheduler.get_last_lr()[0] - } - if args.use_ema: - logs["ema_decay"] = 1 - ema_unet.decay - - accelerator.log(logs, step=global_step) - - local_progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - accelerator.wait_for_everyone() - - unet.eval() - text_encoder.eval() - - cur_loss_val = AverageMeter() - cur_acc_val = AverageMeter() - - with torch.inference_mode(): - with on_eval(): - for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(step, batch, True) - - loss = loss.detach_() - acc = acc.detach_() - - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) - - avg_loss_val.update(loss, bsz) - avg_acc_val.update(acc, bsz) - - local_progress_bar.update(1) - global_progress_bar.update(1) - - logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), - "val/cur_loss": loss.item(), - "val/cur_acc": acc.item(), - } - local_progress_bar.set_postfix(**logs) - - logs["val/cur_loss"] = cur_loss_val.avg.item() - logs["val/cur_acc"] = cur_acc_val.avg.item() - - accelerator.log(logs, step=global_step) - - local_progress_bar.clear() - global_progress_bar.clear() - - if accelerator.is_main_process: - if avg_acc_val.avg.item() > max_acc_val: - accelerator.print( - f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") - max_acc_val = avg_acc_val.avg.item() - - # Create the pipeline using using the trained modules and save it. - if accelerator.is_main_process: - print("Finished! Saving final checkpoint and resume state.") - checkpointer.save_samples(global_step, args.sample_steps) - checkpointer.save_model() - accelerator.end_training() - - except KeyboardInterrupt: - if accelerator.is_main_process: - print("Interrupted, saving checkpoint and resume state...") - checkpointer.save_model() - accelerator.end_training() - quit() if __name__ == "__main__": diff --git a/train_ti.py b/train_ti.py index 3f4e739..3a55f40 100644 --- a/train_ti.py +++ b/train_ti.py @@ -1,31 +1,15 @@ import argparse -import math -import datetime -import logging -from functools import partial -from pathlib import Path -from contextlib import contextmanager, nullcontext import torch import torch.utils.checkpoint -from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel -import matplotlib.pyplot as plt -from tqdm.auto import tqdm -from transformers import CLIPTextModel -from slugify import slugify - -from util import load_config, load_embeddings_from_dir -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from data.csv import VlpnDataModule, VlpnDataItem -from training.common import loss_step, train_loop, generate_class_images, get_scheduler -from training.lr import LRFinder -from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args -from models.clip.embeddings import patch_managed_embeddings -from models.clip.tokenizer import MultiCLIPTokenizer + +from util import load_config +from data.csv import VlpnDataItem +from training.common import train_setup +from training.modules.ti import train_ti +from training.util import save_args logger = get_logger(__name__) @@ -271,7 +255,7 @@ def parse_args(): parser.add_argument( "--lr_min_lr", type=float, - default=None, + default=0.04, help="Minimum learning rate in the lr scheduler." ) parser.add_argument( @@ -401,19 +385,19 @@ def parse_args(): help="The weight of prior preservation loss." ) parser.add_argument( - "--decay_target", - default=None, + "--emb_decay_target", + default=0.4, type=float, help="Embedding decay target." ) parser.add_argument( - "--decay_factor", + "--emb_decay_factor", default=1, type=float, help="Embedding decay factor." ) parser.add_argument( - "--decay_start", + "--emb_decay_start", default=1e-4, type=float, help="Embedding decay start offset." @@ -491,213 +475,10 @@ def parse_args(): return args -class Checkpointer(CheckpointerBase): - def __init__( - self, - weight_dtype, - accelerator: Accelerator, - vae: AutoencoderKL, - unet: UNet2DConditionModel, - tokenizer: MultiCLIPTokenizer, - text_encoder: CLIPTextModel, - ema_embeddings: EMAModel, - scheduler, - placeholder_token, - new_ids, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - - self.weight_dtype = weight_dtype - self.accelerator = accelerator - self.vae = vae - self.unet = unet - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.ema_embeddings = ema_embeddings - self.scheduler = scheduler - self.placeholder_token = placeholder_token - self.new_ids = new_ids - - @torch.no_grad() - def checkpoint(self, step, postfix): - print("Saving checkpoint for step %d..." % step) - - checkpoints_path = self.output_dir.joinpath("checkpoints") - checkpoints_path.mkdir(parents=True, exist_ok=True) - - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = self.ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() - - with ema_context: - for (token, ids) in zip(self.placeholder_token, self.new_ids): - text_encoder.text_model.embeddings.save_embed( - ids, - checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") - ) - - del text_encoder - - @torch.no_grad() - def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = self.ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() - - with ema_context: - orig_dtype = text_encoder.dtype - text_encoder.to(dtype=self.weight_dtype) - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=self.unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ).to(self.accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) - - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) - - text_encoder.to(dtype=orig_dtype) - - del text_encoder - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - def main(): args = parse_args() - global_step_offset = args.global_step - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - basepath = Path(args.output_dir).joinpath(slugify(args.project), now) - basepath.mkdir(parents=True, exist_ok=True) - - accelerator = Accelerator( - log_with=LoggerType.TENSORBOARD, - logging_dir=f"{basepath}", - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision - ) - - logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) - - args.seed = args.seed or (torch.random.seed() >> 32) - set_seed(args.seed) - - save_args(basepath, args) - - # Load the tokenizer and add the placeholder token as a additional special token - if args.tokenizer_name: - tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) - elif args.pretrained_model_name_or_path: - tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') - tokenizer.set_use_vector_shuffle(args.vector_shuffle) - tokenizer.set_dropout(args.vector_dropout) - - # Load models and create wrapper for stable diffusion - text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler') - checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( - args.pretrained_model_name_or_path, subfolder='scheduler') - - vae.enable_slicing() - vae.set_use_memory_efficient_attention_xformers(True) - unet.set_use_memory_efficient_attention_xformers(True) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - - embeddings = patch_managed_embeddings(text_encoder) - ema_embeddings = None - - if args.embeddings_dir is not None: - embeddings_dir = Path(args.embeddings_dir) - if not embeddings_dir.exists() or not embeddings_dir.is_dir(): - raise ValueError("--embeddings_dir must point to an existing directory") - - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") - - # Convert the initializer_token, placeholder_token to ids - initializer_token_ids = [ - tokenizer.encode(token, add_special_tokens=False) - for token in args.initializer_token - ] - - new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) - embeddings.resize(len(tokenizer)) - - for (new_id, init_ids) in zip(new_ids, initializer_token_ids): - embeddings.add_embed(new_id, init_ids) - - init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)] - - print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}") - - if args.use_ema: - ema_embeddings = EMAModel( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - ) - - vae.requires_grad_(False) - unet.requires_grad_(False) - - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * - args.train_batch_size * accelerator.num_processes - ) - - if args.find_lr: - args.learning_rate = 1e-5 - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - # Initialize the optimizer - optimizer = optimizer_class( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), # only optimize the embeddings - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - amsgrad=args.adam_amsgrad, - ) - - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - def keyword_filter(item: VlpnDataItem): + def data_filter(item: VlpnDataItem): cond1 = any( keyword in part for keyword in args.placeholder_token @@ -710,198 +491,78 @@ def main(): ) return cond1 and cond3 and cond4 - datamodule = VlpnDataModule( + setup = train_setup( + output_dir=args.output_dir, + project=args.project, + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + learning_rate=args.learning_rate, data_file=args.train_data_file, - batch_size=args.train_batch_size, - tokenizer=tokenizer, - class_subdir=args.class_image_dir, + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + seed=args.seed, + vector_shuffle=args.vector_shuffle, + vector_dropout=args.vector_dropout, + gradient_checkpointing=args.gradient_checkpointing, + embeddings_dir=args.embeddings_dir, + placeholder_token=args.placeholder_token, + initializer_token=args.initializer_token, + num_vectors=args.num_vectors, + scale_lr=args.scale_lr, + use_8bit_adam=args.use_8bit_adam, + train_batch_size=args.train_batch_size, + class_image_dir=args.class_image_dir, num_class_images=args.num_class_images, - size=args.resolution, + resolution=args.resolution, num_buckets=args.num_buckets, progressive_buckets=args.progressive_buckets, bucket_step_size=args.bucket_step_size, bucket_max_pixels=args.bucket_max_pixels, - dropout=args.tag_dropout, - shuffle=not args.no_tag_shuffle, - template_key=args.train_data_template, + tag_dropout=args.tag_dropout, + tag_shuffle=not args.no_tag_shuffle, + data_template=args.train_data_template, valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, - num_workers=args.dataloader_num_workers, - seed=args.seed, - filter=keyword_filter, - dtype=weight_dtype - ) - datamodule.setup() - - train_dataloader = datamodule.train_dataloader - val_dataloader = datamodule.val_dataloader - - if args.num_class_images != 0: - generate_class_images( - accelerator, - text_encoder, - vae, - unet, - tokenizer, - checkpoint_scheduler, - datamodule.data_train, - args.sample_batch_size, - args.sample_image_size, - args.sample_steps - ) - - if args.find_lr: - lr_scheduler = None - else: - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - min_lr=args.lr_min_lr, - lr=args.learning_rate, - warmup_func=args.lr_warmup_func, - annealing_func=args.lr_annealing_func, - warmup_exp=args.lr_warmup_exp, - annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, - train_epochs=args.num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, - num_training_steps_per_epoch=len(train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps - ) - - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + data_filter=data_filter, + sample_image_size=args.sample_image_size, + sample_batch_size=args.sample_batch_size, + sample_steps=args.sample_steps, ) - # Move vae and unet to device - vae.to(accelerator.device, dtype=weight_dtype) - unet.to(accelerator.device, dtype=weight_dtype) - - if args.use_ema: - ema_embeddings.to(accelerator.device) + save_args(setup.output_dir, args) - # Keep vae and unet in eval mode as we don't train these - vae.eval() - - if args.gradient_checkpointing: - unet.train() - else: - unet.eval() - - @contextmanager - def on_train(): - try: - tokenizer.train() - yield - finally: - pass - - @contextmanager - def on_eval(): - try: - tokenizer.eval() - - ema_context = ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext() - - with ema_context: - yield - finally: - pass - - @torch.no_grad() - def on_after_optimize(lr: float): - text_encoder.text_model.embeddings.normalize( - args.decay_target, - min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) - ) - - if args.use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - def on_log(): - if args.use_ema: - return {"ema_decay": ema_embeddings.decay} - return {} - - loss_step_ = partial( - loss_step, - vae, - noise_scheduler, - unet, - text_encoder, - args.num_class_images != 0, - args.prior_loss_weight, - args.seed, - ) - - checkpointer = Checkpointer( - weight_dtype=weight_dtype, - datamodule=datamodule, - accelerator=accelerator, - vae=vae, - unet=unet, - tokenizer=tokenizer, - text_encoder=text_encoder, - ema_embeddings=ema_embeddings, - scheduler=checkpoint_scheduler, - placeholder_token=args.placeholder_token, - new_ids=new_ids, - output_dir=basepath, + train_ti( + setup=setup, + num_train_epochs=args.num_train_epochs, + num_class_images=args.num_class_images, + prior_loss_weight=args.prior_loss_weight, + use_ema=args.use_ema, + ema_inv_gamma=args.ema_inv_gamma, + ema_power=args.ema_power, + ema_max_decay=args.ema_max_decay, + adam_beta1=args.adam_beta1, + adam_beta2=args.adam_beta2, + adam_weight_decay=args.adam_weight_decay, + adam_epsilon=args.adam_epsilon, + adam_amsgrad=args.adam_amsgrad, + lr_scheduler=args.lr_scheduler, + lr_min_lr=args.lr_min_lr, + lr_warmup_func=args.lr_warmup_func, + lr_annealing_func=args.lr_annealing_func, + lr_warmup_exp=args.lr_warmup_exp, + lr_annealing_exp=args.lr_annealing_exp, + lr_cycles=args.lr_cycles, + lr_warmup_epochs=args.lr_warmup_epochs, + emb_decay_target=args.emb_decay_target, + emb_decay_factor=args.emb_decay_factor, + emb_decay_start=args.emb_decay_start, sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, sample_batches=args.sample_batches, - seed=args.seed - ) - - if accelerator.is_main_process: - config = vars(args).copy() - config["initializer_token"] = " ".join(config["initializer_token"]) - config["placeholder_token"] = " ".join(config["placeholder_token"]) - config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]]) - if config["collection"] is not None: - config["collection"] = " ".join(config["collection"]) - if config["exclude_collections"] is not None: - config["exclude_collections"] = " ".join(config["exclude_collections"]) - accelerator.init_trackers("textual_inversion", config=config) - - if args.find_lr: - lr_finder = LRFinder( - accelerator=accelerator, - optimizer=optimizer, - model=text_encoder, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - loss_step=loss_step_, - on_train=on_train, - on_eval=on_eval, - on_after_optimize=on_after_optimize, - ) - lr_finder.run(num_epochs=100, end_lr=1e3) - - plt.savefig(basepath.joinpath("lr.png"), dpi=300) - plt.close() - else: - train_loop( - accelerator=accelerator, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model=text_encoder, - checkpointer=checkpointer, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - loss_step=loss_step_, - sample_frequency=args.sample_frequency, - sample_steps=args.sample_steps, - checkpoint_frequency=args.checkpoint_frequency, - global_step_offset=global_step_offset, - gradient_accumulation_steps=args.gradient_accumulation_steps, - num_epochs=args.num_train_epochs, - on_log=on_log, - on_train=on_train, - on_after_optimize=on_after_optimize, - on_eval=on_eval - ) + sample_frequency=args.sample_frequency, + sample_steps=args.sample_steps, + checkpoint_frequency=args.checkpoint_frequency, + global_step_offset=args.global_step, + ) if __name__ == "__main__": diff --git a/training/common.py b/training/common.py index 180396e..73ce814 100644 --- a/training/common.py +++ b/training/common.py @@ -1,46 +1,77 @@ import math +from pathlib import Path from contextlib import _GeneratorContextManager, nullcontext -from typing import Callable, Any, Tuple, Union +from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple +import datetime +import logging import torch import torch.nn.functional as F from torch.utils.data import DataLoader from accelerate import Accelerator -from transformers import CLIPTokenizer, CLIPTextModel -from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel +from accelerate.utils import LoggerType, set_seed +from transformers import CLIPTextModel +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup from tqdm.auto import tqdm +from slugify import slugify +from data.csv import VlpnDataModule, VlpnDataItem +from util import load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from models.clip.embeddings import patch_managed_embeddings from models.clip.util import get_extended_embeddings +from models.clip.tokenizer import MultiCLIPTokenizer from training.optimization import get_one_cycle_schedule from training.util import AverageMeter, CheckpointerBase +class TrainingSetup(NamedTuple): + accelerator: Accelerator + tokenizer: MultiCLIPTokenizer + text_encoder: CLIPTextModel + vae: AutoencoderKL + unet: UNet2DConditionModel + noise_scheduler: DDPMScheduler + checkpoint_scheduler: DPMSolverMultistepScheduler + optimizer_class: Callable + learning_rate: float + weight_dtype: torch.dtype + output_dir: Path + seed: int + train_dataloader: DataLoader + val_dataloader: DataLoader + placeholder_token: list[str] + placeholder_token_ids: list[list[int]] + + def noop(*args, **kwards): pass +def noop_ctx(*args, **kwards): + return nullcontext() + + def noop_on_log(): return {} def get_scheduler( id: str, - min_lr: float, - lr: float, - warmup_func: str, - annealing_func: str, - warmup_exp: int, - annealing_exp: int, - cycles: int, - train_epochs: int, - warmup_epochs: int, optimizer: torch.optim.Optimizer, num_training_steps_per_epoch: int, gradient_accumulation_steps: int, + min_lr: float = 0.04, + warmup_func: str = "cos", + annealing_func: str = "cos", + warmup_exp: int = 1, + annealing_exp: int = 1, + cycles: int = 1, + train_epochs: int = 100, + warmup_epochs: int = 10, ): num_training_steps_per_epoch = math.ceil( num_training_steps_per_epoch / gradient_accumulation_steps @@ -49,8 +80,6 @@ def get_scheduler( num_warmup_steps = warmup_epochs * num_training_steps_per_epoch if id == "one_cycle": - min_lr = 0.04 if min_lr is None else min_lr / lr - lr_scheduler = get_one_cycle_schedule( optimizer=optimizer, num_training_steps=num_training_steps, @@ -133,6 +162,196 @@ def generate_class_images( torch.cuda.empty_cache() +def train_setup( + output_dir: str, + project: str, + pretrained_model_name_or_path: str, + learning_rate: float, + data_file: str, + gradient_accumulation_steps: int = 1, + mixed_precision: Literal["no", "fp16", "bf16"] = "no", + seed: Optional[int] = None, + vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto", + vector_dropout: float = 0.1, + gradient_checkpointing: bool = True, + embeddings_dir: Optional[str] = None, + placeholder_token: list[str] = [], + initializer_token: list[str] = [], + num_vectors: int = 1, + scale_lr: bool = False, + use_8bit_adam: bool = False, + train_batch_size: int = 1, + class_image_dir: Optional[str] = None, + num_class_images: int = 0, + resolution: int = 768, + num_buckets: int = 0, + progressive_buckets: bool = False, + bucket_step_size: int = 64, + bucket_max_pixels: Optional[int] = None, + tag_dropout: float = 0.1, + tag_shuffle: bool = True, + data_template: str = "template", + valid_set_size: Optional[int] = None, + valid_set_repeat: int = 1, + data_filter: Optional[Callable[[VlpnDataItem], bool]] = None, + sample_batch_size: int = 1, + sample_image_size: int = 768, + sample_steps: int = 20, +) -> TrainingSetup: + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + output_dir = Path(output_dir).joinpath(slugify(project), now) + output_dir.mkdir(parents=True, exist_ok=True) + + accelerator = Accelerator( + log_with=LoggerType.TENSORBOARD, + logging_dir=f"{output_dir}", + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision=mixed_precision + ) + + logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + + seed = seed or (torch.random.seed() >> 32) + set_seed(seed) + + # Load the tokenizer and add the placeholder token as a additional special token + tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') + tokenizer.set_use_vector_shuffle(vector_shuffle) + tokenizer.set_dropout(vector_dropout) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') + unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') + noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') + checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder='scheduler') + + vae.enable_slicing() + vae.set_use_memory_efficient_attention_xformers(True) + unet.set_use_memory_efficient_attention_xformers(True) + + if gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + embeddings = patch_managed_embeddings(text_encoder) + + if embeddings_dir is not None: + embeddings_dir = Path(embeddings_dir) + if not embeddings_dir.exists() or not embeddings_dir.is_dir(): + raise ValueError("--embeddings_dir must point to an existing directory") + + added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) + print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + + # Convert the initializer_token, placeholder_token to ids + initializer_token_ids = [ + tokenizer.encode(token, add_special_tokens=False) + for token in initializer_token + ] + + placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors) + embeddings.resize(len(tokenizer)) + + for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids): + embeddings.add_embed(new_id, init_ids) + + init_ratios = [ + f"{len(init_ids)} / {len(new_id)}" + for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids) + ] + + print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}") + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + + if scale_lr: + learning_rate = ( + learning_rate * gradient_accumulation_steps * + train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + weight_dtype = torch.float32 + if mixed_precision == "fp16": + weight_dtype = torch.float16 + elif mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + datamodule = VlpnDataModule( + data_file=data_file, + batch_size=train_batch_size, + tokenizer=tokenizer, + class_subdir=class_image_dir, + num_class_images=num_class_images, + size=resolution, + num_buckets=num_buckets, + progressive_buckets=progressive_buckets, + bucket_step_size=bucket_step_size, + bucket_max_pixels=bucket_max_pixels, + dropout=tag_dropout, + shuffle=tag_shuffle, + template_key=data_template, + valid_set_size=valid_set_size, + valid_set_repeat=valid_set_repeat, + seed=seed, + filter=data_filter, + dtype=weight_dtype + ) + datamodule.setup() + + train_dataloader = datamodule.train_dataloader + val_dataloader = datamodule.val_dataloader + + train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader) + + if num_class_images != 0: + generate_class_images( + accelerator, + text_encoder, + vae, + unet, + tokenizer, + checkpoint_scheduler, + datamodule.data_train, + sample_batch_size, + sample_image_size, + sample_steps + ) + + return TrainingSetup( + accelerator=accelerator, + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + noise_scheduler=noise_scheduler, + checkpoint_scheduler=checkpoint_scheduler, + optimizer_class=optimizer_class, + learning_rate=learning_rate, + output_dir=output_dir, + weight_dtype=weight_dtype, + seed=seed, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + placeholder_token=placeholder_token, + placeholder_token_ids=placeholder_token_ids + ) + + def loss_step( vae: AutoencoderKL, noise_scheduler: DDPMScheduler, @@ -221,15 +440,14 @@ def train_loop( sample_steps: int = 20, checkpoint_frequency: int = 50, global_step_offset: int = 0, - gradient_accumulation_steps: int = 1, num_epochs: int = 100, on_log: Callable[[], dict[str, Any]] = noop_on_log, - on_train: Callable[[], _GeneratorContextManager] = nullcontext, - on_before_optimize: Callable[[], None] = noop, + on_train: Callable[[int], _GeneratorContextManager] = noop_ctx, + on_before_optimize: Callable[[int], None] = noop, on_after_optimize: Callable[[float], None] = noop, - on_eval: Callable[[], _GeneratorContextManager] = nullcontext + on_eval: Callable[[], _GeneratorContextManager] = noop_ctx ): - num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) num_val_steps_per_epoch = len(val_dataloader) num_training_steps = num_training_steps_per_epoch * num_epochs @@ -273,14 +491,14 @@ def train_loop( model.train() - with on_train(): + with on_train(epoch): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(model): loss, acc, bsz = loss_step(step, batch) accelerator.backward(loss) - on_before_optimize() + on_before_optimize(epoch) optimizer.step() lr_scheduler.step() diff --git a/training/lr.py b/training/lr.py index 84e30a0..7584ba2 100644 --- a/training/lr.py +++ b/training/lr.py @@ -16,6 +16,10 @@ def noop(*args, **kwards): pass +def noop_ctx(*args, **kwards): + return nullcontext() + + class LRFinder(): def __init__( self, @@ -25,10 +29,10 @@ class LRFinder(): train_dataloader, val_dataloader, loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], - on_train: Callable[[], _GeneratorContextManager] = nullcontext, - on_before_optimize: Callable[[], None] = noop, + on_train: Callable[[int], _GeneratorContextManager] = noop_ctx, + on_before_optimize: Callable[[int], None] = noop, on_after_optimize: Callable[[float], None] = noop, - on_eval: Callable[[], _GeneratorContextManager] = nullcontext + on_eval: Callable[[], _GeneratorContextManager] = noop_ctx ): self.accelerator = accelerator self.model = model @@ -86,7 +90,7 @@ class LRFinder(): self.model.train() - with self.on_train(): + with self.on_train(epoch): for step, batch in enumerate(self.train_dataloader): if step >= num_train_batches: break @@ -96,7 +100,7 @@ class LRFinder(): self.accelerator.backward(loss) - self.on_before_optimize() + self.on_before_optimize(epoch) self.optimizer.step() lr_scheduler.step() diff --git a/training/modules/dreambooth.py b/training/modules/dreambooth.py new file mode 100644 index 0000000..e69de29 diff --git a/training/modules/lora.py b/training/modules/lora.py new file mode 100644 index 0000000..e69de29 diff --git a/training/modules/ti.py b/training/modules/ti.py new file mode 100644 index 0000000..2db6f88 --- /dev/null +++ b/training/modules/ti.py @@ -0,0 +1,284 @@ +from typing import Literal +from functools import partial +from contextlib import contextmanager, nullcontext + +import torch + +from slugify import slugify + +from accelerate import Accelerator +from transformers import CLIPTextModel +from diffusers import AutoencoderKL, UNet2DConditionModel + +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from models.clip.tokenizer import MultiCLIPTokenizer + +from training.common import TrainingSetup, get_scheduler, train_loop, loss_step +from training.util import EMAModel, CheckpointerBase + + +class Checkpointer(CheckpointerBase): + def __init__( + self, + accelerator: Accelerator, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + tokenizer: MultiCLIPTokenizer, + text_encoder: CLIPTextModel, + ema_embeddings: EMAModel, + weight_dtype: torch.dtype, + scheduler, + placeholder_token, + placeholder_token_ids, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.weight_dtype = weight_dtype + self.accelerator = accelerator + self.vae = vae + self.unet = unet + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.ema_embeddings = ema_embeddings + self.scheduler = scheduler + self.placeholder_token = placeholder_token + self.placeholder_token_ids = placeholder_token_ids + + @torch.no_grad() + def checkpoint(self, step, postfix): + print("Saving checkpoint for step %d..." % step) + + checkpoints_path = self.output_dir.joinpath("checkpoints") + checkpoints_path.mkdir(parents=True, exist_ok=True) + + text_encoder = self.accelerator.unwrap_model(self.text_encoder) + + ema_context = nullcontext() + if self.ema_embeddings is not None: + ema_context = self.ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + + with ema_context: + for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids): + text_encoder.text_model.embeddings.save_embed( + ids, + checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") + ) + + del text_encoder + + @torch.no_grad() + def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + text_encoder = self.accelerator.unwrap_model(self.text_encoder) + + ema_context = nullcontext() + if self.ema_embeddings is not None: + ema_context = self.ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + + with ema_context: + orig_dtype = text_encoder.dtype + text_encoder.to(dtype=self.weight_dtype) + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=self.vae, + unet=self.unet, + tokenizer=self.tokenizer, + scheduler=self.scheduler, + ).to(self.accelerator.device) + pipeline.set_progress_bar_config(dynamic_ncols=True) + + super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + + text_encoder.to(dtype=orig_dtype) + + del text_encoder + del pipeline + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def train_ti( + setup: TrainingSetup, + num_train_epochs: int = 100, + num_class_images: int = 0, + prior_loss_weight: float = 1.0, + use_ema: bool = False, + ema_inv_gamma: float = 1.0, + ema_power: float = 4/5, + ema_max_decay: float = .9999, + adam_beta1: float = 0.9, + adam_beta2: float = 0.999, + adam_weight_decay: float = 0, + adam_epsilon: float = 1e-08, + adam_amsgrad: bool = False, + lr_scheduler: Literal[ + "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle" + ] = "one_cycle", + lr_min_lr: float = 0.04, + lr_warmup_func: Literal["linear", "cos"] = "cos", + lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos", + lr_warmup_exp: int = 1, + lr_annealing_exp: int = 1, + lr_cycles: int = 1, + lr_warmup_epochs: int = 10, + emb_decay_target: float = 0.4, + emb_decay_factor: float = 1, + emb_decay_start: float = 1e-4, + sample_image_size: int = 768, + sample_batch_size: int = 1, + sample_batches: int = 1, + sample_frequency: int = 10, + sample_steps: int = 20, + checkpoint_frequency: int = 50, + global_step_offset: int = 0, +): + if use_ema: + ema_embeddings = EMAModel( + setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + inv_gamma=ema_inv_gamma, + power=ema_power, + max_value=ema_max_decay, + ) + else: + ema_embeddings = None + + setup.text_encoder.requires_grad_(True) + setup.text_encoder.text_model.encoder.requires_grad_(False) + setup.text_encoder.text_model.final_layer_norm.requires_grad_(False) + setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) + + # Initialize the optimizer + optimizer = setup.optimizer_class( + setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + lr=setup.learning_rate, + betas=(adam_beta1, adam_beta2), + weight_decay=adam_weight_decay, + eps=adam_epsilon, + amsgrad=adam_amsgrad, + ) + + lr_scheduler = get_scheduler( + lr_scheduler, + optimizer=optimizer, + min_lr=lr_min_lr, + warmup_func=lr_warmup_func, + annealing_func=lr_annealing_func, + warmup_exp=lr_warmup_exp, + annealing_exp=lr_annealing_exp, + cycles=lr_cycles, + train_epochs=num_train_epochs, + warmup_epochs=lr_warmup_epochs, + num_training_steps_per_epoch=len(setup.train_dataloader), + gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps + ) + + text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare( + setup.text_encoder, optimizer, lr_scheduler + ) + + # Move vae and unet to device + setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype) + setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype) + + if use_ema: + ema_embeddings.to(setup.accelerator.device) + + setup.unet.train() + + @contextmanager + def on_train(epoch: int): + try: + setup.tokenizer.train() + yield + finally: + pass + + @contextmanager + def on_eval(): + try: + setup.tokenizer.eval() + + ema_context = nullcontext() + if use_ema: + ema_context = ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + + with ema_context: + yield + finally: + pass + + @torch.no_grad() + def on_after_optimize(lr: float): + text_encoder.text_model.embeddings.normalize( + emb_decay_target, + min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start)))) + ) + + if use_ema: + ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + + def on_log(): + if use_ema: + return {"ema_decay": ema_embeddings.decay} + return {} + + loss_step_ = partial( + loss_step, + setup.vae, + setup.noise_scheduler, + setup.unet, + text_encoder, + num_class_images != 0, + prior_loss_weight, + setup.seed, + ) + + checkpointer = Checkpointer( + accelerator=setup.accelerator, + vae=setup.vae, + unet=setup.unet, + tokenizer=setup.tokenizer, + text_encoder=text_encoder, + ema_embeddings=ema_embeddings, + weight_dtype=setup.weight_dtype, + scheduler=setup.checkpoint_scheduler, + placeholder_token=setup.placeholder_token, + placeholder_token_ids=setup.placeholder_token_ids, + train_dataloader=setup.train_dataloader, + val_dataloader=setup.val_dataloader, + output_dir=setup.output_dir, + seed=setup.seed, + sample_image_size=sample_image_size, + sample_batch_size=sample_batch_size, + sample_batches=sample_batches + ) + + if setup.accelerator.is_main_process: + setup.accelerator.init_trackers("textual_inversion") + + train_loop( + accelerator=setup.accelerator, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + model=text_encoder, + checkpointer=checkpointer, + train_dataloader=setup.train_dataloader, + val_dataloader=setup.val_dataloader, + loss_step=loss_step_, + sample_frequency=sample_frequency, + sample_steps=sample_steps, + checkpoint_frequency=checkpoint_frequency, + global_step_offset=global_step_offset, + num_epochs=num_train_epochs, + on_log=on_log, + on_train=on_train, + on_after_optimize=on_after_optimize, + on_eval=on_eval + ) diff --git a/training/util.py b/training/util.py index 0ec2032..cc4cdee 100644 --- a/training/util.py +++ b/training/util.py @@ -41,14 +41,16 @@ class AverageMeter: class CheckpointerBase: def __init__( self, - datamodule, + train_dataloader, + val_dataloader, output_dir: Path, sample_image_size: int, sample_batches: int, sample_batch_size: int, seed: Optional[int] = None ): - self.datamodule = datamodule + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader self.output_dir = output_dir self.sample_image_size = sample_image_size self.seed = seed if seed is not None else torch.random.seed() @@ -70,15 +72,16 @@ class CheckpointerBase: ): samples_path = Path(self.output_dir).joinpath("samples") - train_data = self.datamodule.train_dataloader - val_data = self.datamodule.val_dataloader - generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) grid_cols = min(self.sample_batch_size, 4) grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols - for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]: + for pool, data, gen in [ + ("stable", self.val_dataloader, generator), + ("val", self.val_dataloader, None), + ("train", self.train_dataloader, None) + ]: all_samples = [] file_path = samples_path.joinpath(pool, f"step_{step}.jpg") file_path.parent.mkdir(parents=True, exist_ok=True) -- cgit v1.2.3-70-g09d2