From 3e7fbb7dce321435bbbb81361debfbc499bf9231 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 13 Jan 2023 22:25:30 +0100 Subject: Reverted modularization mostly --- train_dreambooth.py | 3 +- train_ti.py | 467 ++++++++++++++++++++++++++++++++++------- training/common.py | 264 ++--------------------- training/modules/dreambooth.py | 0 training/modules/lora.py | 0 training/modules/ti.py | 284 ------------------------- training/optimization.py | 53 +++++ 7 files changed, 458 insertions(+), 613 deletions(-) delete mode 100644 training/modules/dreambooth.py delete mode 100644 training/modules/lora.py delete mode 100644 training/modules/ti.py diff --git a/train_dreambooth.py b/train_dreambooth.py index c892ebf..2145e2b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -21,7 +21,8 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import VlpnDataModule, VlpnDataItem -from training.common import loss_step, train_loop, generate_class_images, get_scheduler +from training.common import loss_step, train_loop, generate_class_images +from training.optimization import get_scheduler from training.lr import LRFinder from training.util import CheckpointerBase, save_args from models.clip.embeddings import patch_managed_embeddings diff --git a/train_ti.py b/train_ti.py index 3a55f40..61195f6 100644 --- a/train_ti.py +++ b/train_ti.py @@ -1,15 +1,29 @@ import argparse +import datetime +import logging +from functools import partial +from pathlib import Path +from contextlib import contextmanager, nullcontext import torch import torch.utils.checkpoint +from accelerate import Accelerator from accelerate.logging import get_logger - -from util import load_config -from data.csv import VlpnDataItem -from training.common import train_setup -from training.modules.ti import train_ti -from training.util import save_args +from accelerate.utils import LoggerType, set_seed +from diffusers import AutoencoderKL, UNet2DConditionModel +import matplotlib.pyplot as plt +from transformers import CLIPTextModel +from slugify import slugify + +from util import load_config, load_embeddings_from_dir +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from data.csv import VlpnDataModule, VlpnDataItem +from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models +from training.optimization import get_scheduler +from training.lr import LRFinder +from training.util import CheckpointerBase, EMAModel, save_args +from models.clip.tokenizer import MultiCLIPTokenizer logger = get_logger(__name__) @@ -52,13 +66,13 @@ def parse_args(): help="The name of the current project.", ) parser.add_argument( - "--placeholder_token", + "--placeholder_tokens", type=str, nargs='*', help="A token to use as a placeholder for the concept.", ) parser.add_argument( - "--initializer_token", + "--initializer_tokens", type=str, nargs='*', help="A token to use as initializer word." 
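The hunk above renames the token options to their plural forms; both already use nargs='*', so each parses to a list of strings. A standalone sketch of that parsing behaviour follows; the option names come from this patch, while the stripped-down parser and the example values are invented for illustration.

import argparse

# Minimal stand-in for the renamed train_ti.py options (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument("--placeholder_tokens", type=str, nargs="*", default=[],
                    help="Tokens to use as placeholders for the concepts.")
parser.add_argument("--initializer_tokens", type=str, nargs="*", default=[],
                    help="Tokens to use as initializer words.")

args = parser.parse_args([
    "--placeholder_tokens", "<tok-a>", "<tok-b>",   # one placeholder per new embedding
    "--initializer_tokens", "woman", "man",         # paired 1:1 with the placeholders
])
assert args.placeholder_tokens == ["<tok-a>", "<tok-b>"]
assert args.initializer_tokens == ["woman", "man"]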
@@ -439,29 +453,29 @@ def parse_args(): if args.project is None: raise ValueError("You must specify --project") - if isinstance(args.placeholder_token, str): - args.placeholder_token = [args.placeholder_token] + if isinstance(args.placeholder_tokens, str): + args.placeholder_tokens = [args.placeholder_tokens] - if len(args.placeholder_token) == 0: - args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] + if len(args.placeholder_tokens) == 0: + args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)] - if isinstance(args.initializer_token, str): - args.initializer_token = [args.initializer_token] * len(args.placeholder_token) + if isinstance(args.initializer_tokens, str): + args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) - if len(args.initializer_token) == 0: - raise ValueError("You must specify --initializer_token") + if len(args.initializer_tokens) == 0: + raise ValueError("You must specify --initializer_tokens") - if len(args.placeholder_token) != len(args.initializer_token): - raise ValueError("--placeholder_token and --initializer_token must have the same number of items") + if len(args.placeholder_tokens) != len(args.initializer_tokens): + raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") if args.num_vectors is None: args.num_vectors = 1 if isinstance(args.num_vectors, int): - args.num_vectors = [args.num_vectors] * len(args.initializer_token) + args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) - if len(args.placeholder_token) != len(args.num_vectors): - raise ValueError("--placeholder_token and --num_vectors must have the same number of items") + if len(args.placeholder_tokens) != len(args.num_vectors): + raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") if isinstance(args.collection, str): args.collection = [args.collection] @@ -475,13 +489,197 @@ def parse_args(): return args +class Checkpointer(CheckpointerBase): + def __init__( + self, + weight_dtype, + accelerator: Accelerator, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + tokenizer: MultiCLIPTokenizer, + text_encoder: CLIPTextModel, + ema_embeddings: EMAModel, + scheduler, + placeholder_tokens, + placeholder_token_ids, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.weight_dtype = weight_dtype + self.accelerator = accelerator + self.vae = vae + self.unet = unet + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.ema_embeddings = ema_embeddings + self.scheduler = scheduler + self.placeholder_tokens = placeholder_tokens + self.placeholder_token_ids = placeholder_token_ids + + @torch.no_grad() + def checkpoint(self, step, postfix): + print("Saving checkpoint for step %d..." 
% step) + + checkpoints_path = self.output_dir.joinpath("checkpoints") + checkpoints_path.mkdir(parents=True, exist_ok=True) + + text_encoder = self.accelerator.unwrap_model(self.text_encoder) + + ema_context = self.ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() + + with ema_context: + for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids): + text_encoder.text_model.embeddings.save_embed( + ids, + checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") + ) + + del text_encoder + + @torch.no_grad() + def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): + text_encoder = self.accelerator.unwrap_model(self.text_encoder) + + ema_context = self.ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() + + with ema_context: + orig_dtype = text_encoder.dtype + text_encoder.to(dtype=self.weight_dtype) + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=self.vae, + unet=self.unet, + tokenizer=self.tokenizer, + scheduler=self.scheduler, + ).to(self.accelerator.device) + pipeline.set_progress_bar_config(dynamic_ncols=True) + + super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) + + text_encoder.to(dtype=orig_dtype) + + del text_encoder + del pipeline + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def main(): args = parse_args() - def data_filter(item: VlpnDataItem): + global_step_offset = args.global_step + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + basepath = Path(args.output_dir).joinpath(slugify(args.project), now) + basepath.mkdir(parents=True, exist_ok=True) + + accelerator = Accelerator( + log_with=LoggerType.TENSORBOARD, + logging_dir=f"{basepath}", + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision + ) + + logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) + + args.seed = args.seed or (torch.random.seed() >> 32) + set_seed(args.seed) + + save_args(basepath, args) + + tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( + args.pretrained_model_name_or_path) + + tokenizer.set_use_vector_shuffle(args.vector_shuffle) + tokenizer.set_dropout(args.vector_dropout) + + vae.enable_slicing() + vae.set_use_memory_efficient_attention_xformers(True) + unet.set_use_memory_efficient_attention_xformers(True) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + + if args.embeddings_dir is not None: + embeddings_dir = Path(args.embeddings_dir) + if not embeddings_dir.exists() or not embeddings_dir.is_dir(): + raise ValueError("--embeddings_dir must point to an existing directory") + + added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) + print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + + placeholder_token_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=args.placeholder_tokens, + initializer_tokens=args.initializer_tokens, + num_vectors=args.num_vectors + ) + + print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") + + if args.use_ema: + ema_embeddings = EMAModel( + 
text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + max_value=args.ema_max_decay, + ) + else: + ema_embeddings = None + + vae.requires_grad_(False) + unet.requires_grad_(False) + + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * + args.train_batch_size * accelerator.num_processes + ) + + if args.find_lr: + args.learning_rate = 1e-5 + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + amsgrad=args.adam_amsgrad, + ) + + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + def keyword_filter(item: VlpnDataItem): cond1 = any( keyword in part - for keyword in args.placeholder_token + for keyword in args.placeholder_tokens for part in item.prompt ) cond3 = args.collection is None or args.collection in item.collection @@ -491,78 +689,185 @@ def main(): ) return cond1 and cond3 and cond4 - setup = train_setup( - output_dir=args.output_dir, - project=args.project, - pretrained_model_name_or_path=args.pretrained_model_name_or_path, - learning_rate=args.learning_rate, + datamodule = VlpnDataModule( data_file=args.train_data_file, - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - seed=args.seed, - vector_shuffle=args.vector_shuffle, - vector_dropout=args.vector_dropout, - gradient_checkpointing=args.gradient_checkpointing, - embeddings_dir=args.embeddings_dir, - placeholder_token=args.placeholder_token, - initializer_token=args.initializer_token, - num_vectors=args.num_vectors, - scale_lr=args.scale_lr, - use_8bit_adam=args.use_8bit_adam, - train_batch_size=args.train_batch_size, - class_image_dir=args.class_image_dir, + batch_size=args.train_batch_size, + tokenizer=tokenizer, + class_subdir=args.class_image_dir, num_class_images=args.num_class_images, - resolution=args.resolution, + size=args.resolution, num_buckets=args.num_buckets, progressive_buckets=args.progressive_buckets, bucket_step_size=args.bucket_step_size, bucket_max_pixels=args.bucket_max_pixels, - tag_dropout=args.tag_dropout, - tag_shuffle=not args.no_tag_shuffle, - data_template=args.train_data_template, + dropout=args.tag_dropout, + shuffle=not args.no_tag_shuffle, + template_key=args.train_data_template, valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, - data_filter=data_filter, - sample_image_size=args.sample_image_size, - sample_batch_size=args.sample_batch_size, - sample_steps=args.sample_steps, + num_workers=args.dataloader_num_workers, + seed=args.seed, + filter=keyword_filter, + dtype=weight_dtype + ) + datamodule.setup() + + train_dataloader = 
datamodule.train_dataloader + val_dataloader = datamodule.val_dataloader + + if args.num_class_images != 0: + generate_class_images( + accelerator, + text_encoder, + vae, + unet, + tokenizer, + sample_scheduler, + datamodule.data_train, + args.sample_batch_size, + args.sample_image_size, + args.sample_steps + ) + + if args.find_lr: + lr_scheduler = None + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps_per_epoch=len(train_dataloader), + gradient_accumulation_steps=args.gradient_accumulation_steps, + min_lr=args.lr_min_lr, + warmup_func=args.lr_warmup_func, + annealing_func=args.lr_annealing_func, + warmup_exp=args.lr_warmup_exp, + annealing_exp=args.lr_annealing_exp, + cycles=args.lr_cycles, + train_epochs=args.num_train_epochs, + warmup_epochs=args.lr_warmup_epochs, + ) + + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler ) - save_args(setup.output_dir, args) + vae.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) - train_ti( - setup=setup, - num_train_epochs=args.num_train_epochs, - num_class_images=args.num_class_images, - prior_loss_weight=args.prior_loss_weight, - use_ema=args.use_ema, - ema_inv_gamma=args.ema_inv_gamma, - ema_power=args.ema_power, - ema_max_decay=args.ema_max_decay, - adam_beta1=args.adam_beta1, - adam_beta2=args.adam_beta2, - adam_weight_decay=args.adam_weight_decay, - adam_epsilon=args.adam_epsilon, - adam_amsgrad=args.adam_amsgrad, - lr_scheduler=args.lr_scheduler, - lr_min_lr=args.lr_min_lr, - lr_warmup_func=args.lr_warmup_func, - lr_annealing_func=args.lr_annealing_func, - lr_warmup_exp=args.lr_warmup_exp, - lr_annealing_exp=args.lr_annealing_exp, - lr_cycles=args.lr_cycles, - lr_warmup_epochs=args.lr_warmup_epochs, - emb_decay_target=args.emb_decay_target, - emb_decay_factor=args.emb_decay_factor, - emb_decay_start=args.emb_decay_start, + if args.use_ema: + ema_embeddings.to(accelerator.device) + + if args.gradient_checkpointing: + unet.train() + else: + unet.eval() + + @contextmanager + def on_train(epoch: int): + try: + tokenizer.train() + yield + finally: + pass + + @contextmanager + def on_eval(): + try: + tokenizer.eval() + + ema_context = ema_embeddings.apply_temporary( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext() + + with ema_context: + yield + finally: + pass + + @torch.no_grad() + def on_after_optimize(lr: float): + text_encoder.text_model.embeddings.normalize( + args.emb_decay_target, + min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start)))) + ) + + if args.use_ema: + ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + + def on_log(): + if args.use_ema: + return {"ema_decay": ema_embeddings.decay} + return {} + + loss_step_ = partial( + loss_step, + vae, + noise_scheduler, + unet, + text_encoder, + args.num_class_images != 0, + args.prior_loss_weight, + args.seed, + ) + + checkpointer = Checkpointer( + weight_dtype=weight_dtype, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + accelerator=accelerator, + vae=vae, + unet=unet, + tokenizer=tokenizer, + text_encoder=text_encoder, + ema_embeddings=ema_embeddings, + scheduler=sample_scheduler, + placeholder_tokens=args.placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + output_dir=basepath, 
sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, sample_batches=args.sample_batches, - sample_frequency=args.sample_frequency, - sample_steps=args.sample_steps, - checkpoint_frequency=args.checkpoint_frequency, - global_step_offset=args.global_step, - ) + seed=args.seed + ) + + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion") + + if args.find_lr: + lr_finder = LRFinder( + accelerator=accelerator, + optimizer=optimizer, + model=text_encoder, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + loss_step=loss_step_, + on_train=on_train, + on_eval=on_eval, + on_after_optimize=on_after_optimize, + ) + lr_finder.run(num_epochs=100, end_lr=1e3) + + plt.savefig(basepath.joinpath("lr.png"), dpi=300) + plt.close() + else: + train_loop( + accelerator=accelerator, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + model=text_encoder, + checkpointer=checkpointer, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + loss_step=loss_step_, + sample_frequency=args.sample_frequency, + sample_steps=args.sample_steps, + checkpoint_frequency=args.checkpoint_frequency, + global_step_offset=global_step_offset, + num_epochs=args.num_train_epochs, + on_log=on_log, + on_train=on_train, + on_after_optimize=on_after_optimize, + on_eval=on_eval + ) if __name__ == "__main__": diff --git a/training/common.py b/training/common.py index 73ce814..b6964a3 100644 --- a/training/common.py +++ b/training/common.py @@ -1,52 +1,24 @@ import math -from pathlib import Path from contextlib import _GeneratorContextManager, nullcontext -from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple -import datetime -import logging +from typing import Callable, Any, Tuple, Union import torch import torch.nn.functional as F from torch.utils.data import DataLoader from accelerate import Accelerator -from accelerate.utils import LoggerType, set_seed from transformers import CLIPTextModel from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler -from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup from tqdm.auto import tqdm -from slugify import slugify -from data.csv import VlpnDataModule, VlpnDataItem -from util import load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from models.clip.embeddings import patch_managed_embeddings +from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings from models.clip.util import get_extended_embeddings from models.clip.tokenizer import MultiCLIPTokenizer -from training.optimization import get_one_cycle_schedule from training.util import AverageMeter, CheckpointerBase -class TrainingSetup(NamedTuple): - accelerator: Accelerator - tokenizer: MultiCLIPTokenizer - text_encoder: CLIPTextModel - vae: AutoencoderKL - unet: UNet2DConditionModel - noise_scheduler: DDPMScheduler - checkpoint_scheduler: DPMSolverMultistepScheduler - optimizer_class: Callable - learning_rate: float - weight_dtype: torch.dtype - output_dir: Path - seed: int - train_dataloader: DataLoader - val_dataloader: DataLoader - placeholder_token: list[str] - placeholder_token_ids: list[list[int]] - - def noop(*args, **kwards): pass @@ -59,57 +31,6 @@ def noop_on_log(): return {} -def get_scheduler( - id: str, - optimizer: torch.optim.Optimizer, - num_training_steps_per_epoch: int, - gradient_accumulation_steps: int, - min_lr: 
float = 0.04, - warmup_func: str = "cos", - annealing_func: str = "cos", - warmup_exp: int = 1, - annealing_exp: int = 1, - cycles: int = 1, - train_epochs: int = 100, - warmup_epochs: int = 10, -): - num_training_steps_per_epoch = math.ceil( - num_training_steps_per_epoch / gradient_accumulation_steps - ) * gradient_accumulation_steps - num_training_steps = train_epochs * num_training_steps_per_epoch - num_warmup_steps = warmup_epochs * num_training_steps_per_epoch - - if id == "one_cycle": - lr_scheduler = get_one_cycle_schedule( - optimizer=optimizer, - num_training_steps=num_training_steps, - warmup=warmup_func, - annealing=annealing_func, - warmup_exp=warmup_exp, - annealing_exp=annealing_exp, - min_lr=min_lr, - ) - elif id == "cosine_with_restarts": - if cycles is None: - cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) - - lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - num_cycles=cycles, - ) - else: - lr_scheduler = get_scheduler_( - id, - optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - ) - - return lr_scheduler - - def generate_class_images( accelerator, text_encoder, @@ -162,194 +83,43 @@ def generate_class_images( torch.cuda.empty_cache() -def train_setup( - output_dir: str, - project: str, - pretrained_model_name_or_path: str, - learning_rate: float, - data_file: str, - gradient_accumulation_steps: int = 1, - mixed_precision: Literal["no", "fp16", "bf16"] = "no", - seed: Optional[int] = None, - vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto", - vector_dropout: float = 0.1, - gradient_checkpointing: bool = True, - embeddings_dir: Optional[str] = None, - placeholder_token: list[str] = [], - initializer_token: list[str] = [], - num_vectors: int = 1, - scale_lr: bool = False, - use_8bit_adam: bool = False, - train_batch_size: int = 1, - class_image_dir: Optional[str] = None, - num_class_images: int = 0, - resolution: int = 768, - num_buckets: int = 0, - progressive_buckets: bool = False, - bucket_step_size: int = 64, - bucket_max_pixels: Optional[int] = None, - tag_dropout: float = 0.1, - tag_shuffle: bool = True, - data_template: str = "template", - valid_set_size: Optional[int] = None, - valid_set_repeat: int = 1, - data_filter: Optional[Callable[[VlpnDataItem], bool]] = None, - sample_batch_size: int = 1, - sample_image_size: int = 768, - sample_steps: int = 20, -) -> TrainingSetup: - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(output_dir).joinpath(slugify(project), now) - output_dir.mkdir(parents=True, exist_ok=True) - - accelerator = Accelerator( - log_with=LoggerType.TENSORBOARD, - logging_dir=f"{output_dir}", - gradient_accumulation_steps=gradient_accumulation_steps, - mixed_precision=mixed_precision - ) - - logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) - - seed = seed or (torch.random.seed() >> 32) - set_seed(seed) - - # Load the tokenizer and add the placeholder token as a additional special token +def get_models(pretrained_model_name_or_path: str): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') - tokenizer.set_use_vector_shuffle(vector_shuffle) - tokenizer.set_dropout(vector_dropout) - - # Load models and create wrapper for stable diffusion text_encoder = 
CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') - checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( + sample_scheduler = DPMSolverMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) unet.set_use_memory_efficient_attention_xformers(True) - if gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - embeddings = patch_managed_embeddings(text_encoder) - if embeddings_dir is not None: - embeddings_dir = Path(embeddings_dir) - if not embeddings_dir.exists() or not embeddings_dir.is_dir(): - raise ValueError("--embeddings_dir must point to an existing directory") + return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") - # Convert the initializer_token, placeholder_token to ids +def add_placeholder_tokens( + tokenizer: MultiCLIPTokenizer, + embeddings: ManagedCLIPTextEmbeddings, + placeholder_tokens: list[str], + initializer_tokens: list[str], + num_vectors: Union[list[int], int] +): initializer_token_ids = [ tokenizer.encode(token, add_special_tokens=False) - for token in initializer_token + for token in initializer_tokens ] + placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors) - placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors) embeddings.resize(len(tokenizer)) - for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids): - embeddings.add_embed(new_id, init_ids) - - init_ratios = [ - f"{len(init_ids)} / {len(new_id)}" - for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids) - ] - - print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}") + for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): + embeddings.add_embed(placeholder_token_id, initializer_token_id) - vae.requires_grad_(False) - unet.requires_grad_(False) - text_encoder.requires_grad_(False) - - if scale_lr: - learning_rate = ( - learning_rate * gradient_accumulation_steps * - train_batch_size * accelerator.num_processes - ) - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - weight_dtype = torch.float32 - if mixed_precision == "fp16": - weight_dtype = torch.float16 - elif mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - datamodule = VlpnDataModule( - data_file=data_file, - batch_size=train_batch_size, - tokenizer=tokenizer, - class_subdir=class_image_dir, - num_class_images=num_class_images, - size=resolution, - num_buckets=num_buckets, - progressive_buckets=progressive_buckets, - 
bucket_step_size=bucket_step_size, - bucket_max_pixels=bucket_max_pixels, - dropout=tag_dropout, - shuffle=tag_shuffle, - template_key=data_template, - valid_set_size=valid_set_size, - valid_set_repeat=valid_set_repeat, - seed=seed, - filter=data_filter, - dtype=weight_dtype - ) - datamodule.setup() - - train_dataloader = datamodule.train_dataloader - val_dataloader = datamodule.val_dataloader - - train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader) - - if num_class_images != 0: - generate_class_images( - accelerator, - text_encoder, - vae, - unet, - tokenizer, - checkpoint_scheduler, - datamodule.data_train, - sample_batch_size, - sample_image_size, - sample_steps - ) - - return TrainingSetup( - accelerator=accelerator, - tokenizer=tokenizer, - text_encoder=text_encoder, - vae=vae, - unet=unet, - noise_scheduler=noise_scheduler, - checkpoint_scheduler=checkpoint_scheduler, - optimizer_class=optimizer_class, - learning_rate=learning_rate, - output_dir=output_dir, - weight_dtype=weight_dtype, - seed=seed, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - placeholder_token=placeholder_token, - placeholder_token_ids=placeholder_token_ids - ) + return placeholder_token_ids def loss_step( diff --git a/training/modules/dreambooth.py b/training/modules/dreambooth.py deleted file mode 100644 index e69de29..0000000 diff --git a/training/modules/lora.py b/training/modules/lora.py deleted file mode 100644 index e69de29..0000000 diff --git a/training/modules/ti.py b/training/modules/ti.py deleted file mode 100644 index 2db6f88..0000000 --- a/training/modules/ti.py +++ /dev/null @@ -1,284 +0,0 @@ -from typing import Literal -from functools import partial -from contextlib import contextmanager, nullcontext - -import torch - -from slugify import slugify - -from accelerate import Accelerator -from transformers import CLIPTextModel -from diffusers import AutoencoderKL, UNet2DConditionModel - -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from models.clip.tokenizer import MultiCLIPTokenizer - -from training.common import TrainingSetup, get_scheduler, train_loop, loss_step -from training.util import EMAModel, CheckpointerBase - - -class Checkpointer(CheckpointerBase): - def __init__( - self, - accelerator: Accelerator, - vae: AutoencoderKL, - unet: UNet2DConditionModel, - tokenizer: MultiCLIPTokenizer, - text_encoder: CLIPTextModel, - ema_embeddings: EMAModel, - weight_dtype: torch.dtype, - scheduler, - placeholder_token, - placeholder_token_ids, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - - self.weight_dtype = weight_dtype - self.accelerator = accelerator - self.vae = vae - self.unet = unet - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.ema_embeddings = ema_embeddings - self.scheduler = scheduler - self.placeholder_token = placeholder_token - self.placeholder_token_ids = placeholder_token_ids - - @torch.no_grad() - def checkpoint(self, step, postfix): - print("Saving checkpoint for step %d..." 
% step) - - checkpoints_path = self.output_dir.joinpath("checkpoints") - checkpoints_path.mkdir(parents=True, exist_ok=True) - - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = nullcontext() - if self.ema_embeddings is not None: - ema_context = self.ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - with ema_context: - for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids): - text_encoder.text_model.embeddings.save_embed( - ids, - checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") - ) - - del text_encoder - - @torch.no_grad() - def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = nullcontext() - if self.ema_embeddings is not None: - ema_context = self.ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - with ema_context: - orig_dtype = text_encoder.dtype - text_encoder.to(dtype=self.weight_dtype) - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=self.unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ).to(self.accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) - - super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) - - text_encoder.to(dtype=orig_dtype) - - del text_encoder - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def train_ti( - setup: TrainingSetup, - num_train_epochs: int = 100, - num_class_images: int = 0, - prior_loss_weight: float = 1.0, - use_ema: bool = False, - ema_inv_gamma: float = 1.0, - ema_power: float = 4/5, - ema_max_decay: float = .9999, - adam_beta1: float = 0.9, - adam_beta2: float = 0.999, - adam_weight_decay: float = 0, - adam_epsilon: float = 1e-08, - adam_amsgrad: bool = False, - lr_scheduler: Literal[ - "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle" - ] = "one_cycle", - lr_min_lr: float = 0.04, - lr_warmup_func: Literal["linear", "cos"] = "cos", - lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos", - lr_warmup_exp: int = 1, - lr_annealing_exp: int = 1, - lr_cycles: int = 1, - lr_warmup_epochs: int = 10, - emb_decay_target: float = 0.4, - emb_decay_factor: float = 1, - emb_decay_start: float = 1e-4, - sample_image_size: int = 768, - sample_batch_size: int = 1, - sample_batches: int = 1, - sample_frequency: int = 10, - sample_steps: int = 20, - checkpoint_frequency: int = 50, - global_step_offset: int = 0, -): - if use_ema: - ema_embeddings = EMAModel( - setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - inv_gamma=ema_inv_gamma, - power=ema_power, - max_value=ema_max_decay, - ) - else: - ema_embeddings = None - - setup.text_encoder.requires_grad_(True) - setup.text_encoder.text_model.encoder.requires_grad_(False) - setup.text_encoder.text_model.final_layer_norm.requires_grad_(False) - setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - - # Initialize the optimizer - optimizer = setup.optimizer_class( - setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - lr=setup.learning_rate, - betas=(adam_beta1, adam_beta2), - weight_decay=adam_weight_decay, - eps=adam_epsilon, - amsgrad=adam_amsgrad, - ) - - lr_scheduler 
= get_scheduler( - lr_scheduler, - optimizer=optimizer, - min_lr=lr_min_lr, - warmup_func=lr_warmup_func, - annealing_func=lr_annealing_func, - warmup_exp=lr_warmup_exp, - annealing_exp=lr_annealing_exp, - cycles=lr_cycles, - train_epochs=num_train_epochs, - warmup_epochs=lr_warmup_epochs, - num_training_steps_per_epoch=len(setup.train_dataloader), - gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps - ) - - text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare( - setup.text_encoder, optimizer, lr_scheduler - ) - - # Move vae and unet to device - setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype) - setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype) - - if use_ema: - ema_embeddings.to(setup.accelerator.device) - - setup.unet.train() - - @contextmanager - def on_train(epoch: int): - try: - setup.tokenizer.train() - yield - finally: - pass - - @contextmanager - def on_eval(): - try: - setup.tokenizer.eval() - - ema_context = nullcontext() - if use_ema: - ema_context = ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - with ema_context: - yield - finally: - pass - - @torch.no_grad() - def on_after_optimize(lr: float): - text_encoder.text_model.embeddings.normalize( - emb_decay_target, - min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start)))) - ) - - if use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - - def on_log(): - if use_ema: - return {"ema_decay": ema_embeddings.decay} - return {} - - loss_step_ = partial( - loss_step, - setup.vae, - setup.noise_scheduler, - setup.unet, - text_encoder, - num_class_images != 0, - prior_loss_weight, - setup.seed, - ) - - checkpointer = Checkpointer( - accelerator=setup.accelerator, - vae=setup.vae, - unet=setup.unet, - tokenizer=setup.tokenizer, - text_encoder=text_encoder, - ema_embeddings=ema_embeddings, - weight_dtype=setup.weight_dtype, - scheduler=setup.checkpoint_scheduler, - placeholder_token=setup.placeholder_token, - placeholder_token_ids=setup.placeholder_token_ids, - train_dataloader=setup.train_dataloader, - val_dataloader=setup.val_dataloader, - output_dir=setup.output_dir, - seed=setup.seed, - sample_image_size=sample_image_size, - sample_batch_size=sample_batch_size, - sample_batches=sample_batches - ) - - if setup.accelerator.is_main_process: - setup.accelerator.init_trackers("textual_inversion") - - train_loop( - accelerator=setup.accelerator, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model=text_encoder, - checkpointer=checkpointer, - train_dataloader=setup.train_dataloader, - val_dataloader=setup.val_dataloader, - loss_step=loss_step_, - sample_frequency=sample_frequency, - sample_steps=sample_steps, - checkpoint_frequency=checkpoint_frequency, - global_step_offset=global_step_offset, - num_epochs=num_train_epochs, - on_log=on_log, - on_train=on_train, - on_after_optimize=on_after_optimize, - on_eval=on_eval - ) diff --git a/training/optimization.py b/training/optimization.py index dd84f9c..5db7794 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -5,6 +5,8 @@ from functools import partial import torch from torch.optim.lr_scheduler import LambdaLR +from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup + class OneCyclePhase(NamedTuple): step_min: int @@ -83,3 +85,54 @@ def get_one_cycle_schedule( return phase.min + 
phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_scheduler( + id: str, + optimizer: torch.optim.Optimizer, + num_training_steps_per_epoch: int, + gradient_accumulation_steps: int, + min_lr: float = 0.04, + warmup_func: str = "cos", + annealing_func: str = "cos", + warmup_exp: int = 1, + annealing_exp: int = 1, + cycles: int = 1, + train_epochs: int = 100, + warmup_epochs: int = 10, +): + num_training_steps_per_epoch = math.ceil( + num_training_steps_per_epoch / gradient_accumulation_steps + ) * gradient_accumulation_steps + num_training_steps = train_epochs * num_training_steps_per_epoch + num_warmup_steps = warmup_epochs * num_training_steps_per_epoch + + if id == "one_cycle": + lr_scheduler = get_one_cycle_schedule( + optimizer=optimizer, + num_training_steps=num_training_steps, + warmup=warmup_func, + annealing=annealing_func, + warmup_exp=warmup_exp, + annealing_exp=annealing_exp, + min_lr=min_lr, + ) + elif id == "cosine_with_restarts": + if cycles is None: + cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) + + lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=cycles, + ) + else: + lr_scheduler = get_scheduler_( + id, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + + return lr_scheduler -- cgit v1.2.3-70-g09d2
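For reference, a minimal usage sketch of the get_scheduler() helper that this commit moves out of training/common.py into training/optimization.py. The keyword arguments mirror the signature added above; the optimizer, the step counts, and the scheduler id are illustrative placeholders rather than values taken from the patch, and the import assumes the code runs inside this repository.

import torch

from training.optimization import get_scheduler

# Stand-in model and optimizer; in train_ti.py the optimizer wraps the
# temp_token_embedding parameters instead.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

lr_scheduler = get_scheduler(
    "one_cycle",                       # also accepts "cosine_with_restarts" or any diffusers scheduler id
    optimizer=optimizer,
    num_training_steps_per_epoch=250,  # len(train_dataloader) in the training scripts
    gradient_accumulation_steps=1,
    min_lr=0.04,                       # schedule floor, applied relative to the base LR by the one-cycle lambda
    warmup_func="cos",
    annealing_func="cos",
    cycles=1,
    train_epochs=100,
    warmup_epochs=10,
)

for _ in range(3):                     # standard PyTorch order: optimizer first, then scheduler
    optimizer.step()
    lr_scheduler.step()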