From 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 18:59:26 +0100
Subject: More modularization

---
 training/common.py             | 260 ++++++++++++++++++++++++++++++++++---
 training/lr.py                 |  14 +-
 training/modules/dreambooth.py |   0
 training/modules/lora.py       |   0
 training/modules/ti.py         | 284 +++++++++++++++++++++++++++++++++++++++++
 training/util.py               |  15 ++-
 6 files changed, 541 insertions(+), 32 deletions(-)
 create mode 100644 training/modules/dreambooth.py
 create mode 100644 training/modules/lora.py
 create mode 100644 training/modules/ti.py

diff --git a/training/common.py b/training/common.py
index 180396e..73ce814 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,46 +1,77 @@
 import math
+from pathlib import Path
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
+import datetime
+import logging
 
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
-from transformers import CLIPTokenizer, CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from accelerate.utils import LoggerType, set_seed
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from tqdm.auto import tqdm
+from slugify import slugify
 
+from data.csv import VlpnDataModule, VlpnDataItem
+from util import load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 from training.optimization import get_one_cycle_schedule
 from training.util import AverageMeter, CheckpointerBase
 
 
+class TrainingSetup(NamedTuple):
+    accelerator: Accelerator
+    tokenizer: MultiCLIPTokenizer
+    text_encoder: CLIPTextModel
+    vae: AutoencoderKL
+    unet: UNet2DConditionModel
+    noise_scheduler: DDPMScheduler
+    checkpoint_scheduler: DPMSolverMultistepScheduler
+    optimizer_class: Callable
+    learning_rate: float
+    weight_dtype: torch.dtype
+    output_dir: Path
+    seed: int
+    train_dataloader: DataLoader
+    val_dataloader: DataLoader
+    placeholder_token: list[str]
+    placeholder_token_ids: list[list[int]]
+
+
 def noop(*args, **kwards):
     pass
 
 
+def noop_ctx(*args, **kwards):
+    return nullcontext()
+
+
 def noop_on_log():
     return {}
 
 
 def get_scheduler(
     id: str,
-    min_lr: float,
-    lr: float,
-    warmup_func: str,
-    annealing_func: str,
-    warmup_exp: int,
-    annealing_exp: int,
-    cycles: int,
-    train_epochs: int,
-    warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
     num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
 ):
     num_training_steps_per_epoch = math.ceil(
         num_training_steps_per_epoch / gradient_accumulation_steps
@@ -49,8 +80,6 @@ def get_scheduler(
     num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
-        min_lr = 0.04 if min_lr is None else min_lr / lr
-
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=num_training_steps,
@@ -133,6 +162,196 @@ def generate_class_images(
         torch.cuda.empty_cache()
 
 
+def train_setup(
+    output_dir: str,
+    project: str,
+    pretrained_model_name_or_path: str,
+    learning_rate: float,
+    data_file: str,
+    gradient_accumulation_steps: int = 1,
+    mixed_precision: Literal["no", "fp16", "bf16"] = "no",
+    seed: Optional[int] = None,
+    vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
+    vector_dropout: float = 0.1,
+    gradient_checkpointing: bool = True,
+    embeddings_dir: Optional[str] = None,
+    placeholder_token: list[str] = [],
+    initializer_token: list[str] = [],
+    num_vectors: int = 1,
+    scale_lr: bool = False,
+    use_8bit_adam: bool = False,
+    train_batch_size: int = 1,
+    class_image_dir: Optional[str] = None,
+    num_class_images: int = 0,
+    resolution: int = 768,
+    num_buckets: int = 0,
+    progressive_buckets: bool = False,
+    bucket_step_size: int = 64,
+    bucket_max_pixels: Optional[int] = None,
+    tag_dropout: float = 0.1,
+    tag_shuffle: bool = True,
+    data_template: str = "template",
+    valid_set_size: Optional[int] = None,
+    valid_set_repeat: int = 1,
+    data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
+    sample_batch_size: int = 1,
+    sample_image_size: int = 768,
+    sample_steps: int = 20,
+) -> TrainingSetup:
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(output_dir).joinpath(slugify(project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{output_dir}",
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        mixed_precision=mixed_precision
+    )
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    seed = seed or (torch.random.seed() >> 32)
+    set_seed(seed)
+
+    # Load the tokenizer and add the placeholder token as an additional special token
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(vector_shuffle)
+    tokenizer.set_dropout(vector_dropout)
+
+    # Load models and create wrapper for stable diffusion
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    if embeddings_dir is not None:
+        embeddings_dir = Path(embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_token
+    ]
+
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(new_id, init_ids)
+
+    init_ratios = [
+        f"{len(init_ids)} / {len(new_id)}"
+        for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
+    ]
+
+    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
+
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    if scale_lr:
+        learning_rate = (
+            learning_rate * gradient_accumulation_steps *
+            train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+    if use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    weight_dtype = torch.float32
+    if mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    datamodule = VlpnDataModule(
+        data_file=data_file,
+        batch_size=train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=class_image_dir,
+        num_class_images=num_class_images,
+        size=resolution,
+        num_buckets=num_buckets,
+        progressive_buckets=progressive_buckets,
+        bucket_step_size=bucket_step_size,
+        bucket_max_pixels=bucket_max_pixels,
+        dropout=tag_dropout,
+        shuffle=tag_shuffle,
+        template_key=data_template,
+        valid_set_size=valid_set_size,
+        valid_set_repeat=valid_set_repeat,
+        seed=seed,
+        filter=data_filter,
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    train_dataloader = datamodule.train_dataloader
+    val_dataloader = datamodule.val_dataloader
+
+    train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
+
+    if num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            sample_batch_size,
+            sample_image_size,
+            sample_steps
+        )
+
+    return TrainingSetup(
+        accelerator=accelerator,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        noise_scheduler=noise_scheduler,
+        checkpoint_scheduler=checkpoint_scheduler,
+        optimizer_class=optimizer_class,
+        learning_rate=learning_rate,
+        output_dir=output_dir,
+        weight_dtype=weight_dtype,
+        seed=seed,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        placeholder_token=placeholder_token,
+        placeholder_token_ids=placeholder_token_ids
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
@@ -221,15 +440,14 @@ def train_loop(
     sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
-    gradient_accumulation_steps: int = 1,
     num_epochs: int = 100,
     on_log: Callable[[], dict[str, Any]] = noop_on_log,
-    on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-    on_before_optimize: Callable[[], None] = noop,
+    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
+    on_before_optimize: Callable[[int], None] = noop,
     on_after_optimize: Callable[[float], None] = noop,
-    on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
@@ -273,14 +491,14 @@ def train_loop(
 
         model.train()
 
-        with on_train():
+        with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(model):
                     loss, acc, bsz = loss_step(step, batch)
 
                     accelerator.backward(loss)
 
-                    on_before_optimize()
+                    on_before_optimize(epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
diff --git a/training/lr.py b/training/lr.py
index 84e30a0..7584ba2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -16,6 +16,10 @@ def noop(*args, **kwards):
     pass
 
 
+def noop_ctx(*args, **kwards):
+    return nullcontext()
+
+
 class LRFinder():
     def __init__(
         self,
@@ -25,10 +29,10 @@ class LRFinder():
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-        on_before_optimize: Callable[[], None] = noop,
+        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
+        on_before_optimize: Callable[[int], None] = noop,
         on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
     ):
         self.accelerator = accelerator
         self.model = model
@@ -86,7 +90,7 @@ class LRFinder():
 
             self.model.train()
 
-            with self.on_train():
+            with self.on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -96,7 +100,7 @@ class LRFinder():
 
                     self.accelerator.backward(loss)
 
-                    self.on_before_optimize()
+                    self.on_before_optimize(epoch)
 
                     self.optimizer.step()
                     lr_scheduler.step()
diff --git a/training/modules/dreambooth.py b/training/modules/dreambooth.py
new file mode 100644
index 0000000..e69de29
diff --git a/training/modules/lora.py b/training/modules/lora.py
new file mode 100644
index 0000000..e69de29
diff --git a/training/modules/ti.py b/training/modules/ti.py
new file mode 100644
index 0000000..2db6f88
--- /dev/null
+++ b/training/modules/ti.py
@@ -0,0 +1,284 @@
+from typing import Literal
+from functools import partial
+from contextlib import contextmanager, nullcontext
+
+import torch
+
+from slugify import slugify
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
+from training.util import EMAModel, CheckpointerBase
+
+
+class Checkpointer(CheckpointerBase):
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
+        weight_dtype: torch.dtype,
+        scheduler,
+        placeholder_token,
+        placeholder_token_ids,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.weight_dtype = weight_dtype
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
+        self.scheduler = scheduler
+        self.placeholder_token = placeholder_token
+        self.placeholder_token_ids = placeholder_token_ids
+
+    @torch.no_grad()
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = nullcontext()
+        if self.ema_embeddings is not None:
+            ema_context = self.ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+        del text_encoder
+
+    @torch.no_grad()
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = nullcontext()
+        if self.ema_embeddings is not None:
+            ema_context = self.ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+
+            text_encoder.to(dtype=orig_dtype)
+
+        del text_encoder
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+def train_ti(
+    setup: TrainingSetup,
+    num_train_epochs: int = 100,
+    num_class_images: int = 0,
+    prior_loss_weight: float = 1.0,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: float = 4/5,
+    ema_max_decay: float = .9999,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
+    adam_weight_decay: float = 0,
+    adam_epsilon: float = 1e-08,
+    adam_amsgrad: bool = False,
+    lr_scheduler: Literal[
+        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle"
+    ] = "one_cycle",
+    lr_min_lr: float = 0.04,
+    lr_warmup_func: Literal["linear", "cos"] = "cos",
+    lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
+    lr_warmup_exp: int = 1,
+    lr_annealing_exp: int = 1,
+    lr_cycles: int = 1,
+    lr_warmup_epochs: int = 10,
+    emb_decay_target: float = 0.4,
+    emb_decay_factor: float = 1,
+    emb_decay_start: float = 1e-4,
+    sample_image_size: int = 768,
+    sample_batch_size: int = 1,
+    sample_batches: int = 1,
+    sample_frequency: int = 10,
+    sample_steps: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+):
+    if use_ema:
+        ema_embeddings = EMAModel(
+            setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+    else:
+        ema_embeddings = None
+
+    setup.text_encoder.requires_grad_(True)
+    setup.text_encoder.text_model.encoder.requires_grad_(False)
+    setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
+
+    # Initialize the optimizer
+    optimizer = setup.optimizer_class(
+        setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=setup.learning_rate,
+        betas=(adam_beta1, adam_beta2),
+        weight_decay=adam_weight_decay,
+        eps=adam_epsilon,
+        amsgrad=adam_amsgrad,
+    )
+
+    lr_scheduler = get_scheduler(
+        lr_scheduler,
+        optimizer=optimizer,
+        min_lr=lr_min_lr,
+        warmup_func=lr_warmup_func,
+        annealing_func=lr_annealing_func,
+        warmup_exp=lr_warmup_exp,
+        annealing_exp=lr_annealing_exp,
+        cycles=lr_cycles,
+        train_epochs=num_train_epochs,
+        warmup_epochs=lr_warmup_epochs,
+        num_training_steps_per_epoch=len(setup.train_dataloader),
+        gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
+    )
+
+    text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
+        setup.text_encoder, optimizer, lr_scheduler
+    )
+
+    # Move vae and unet to device
+    setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
+    setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)
+
+    if use_ema:
+        ema_embeddings.to(setup.accelerator.device)
+
+    setup.unet.train()
+
+    @contextmanager
+    def on_train(epoch: int):
+        try:
+            setup.tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval():
+        try:
+            setup.tokenizer.eval()
+
+            ema_context = nullcontext()
+            if use_ema:
+                ema_context = ema_embeddings.apply_temporary(
+                    text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+            with ema_context:
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        text_encoder.text_model.embeddings.normalize(
+            emb_decay_target,
+            min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
+        )
+
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
+    loss_step_ = partial(
+        loss_step,
+        setup.vae,
+        setup.noise_scheduler,
+        setup.unet,
+        text_encoder,
+        num_class_images != 0,
+        prior_loss_weight,
+        setup.seed,
+    )
+
+    checkpointer = Checkpointer(
+        accelerator=setup.accelerator,
+        vae=setup.vae,
+        unet=setup.unet,
+        tokenizer=setup.tokenizer,
+        text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
+        weight_dtype=setup.weight_dtype,
+        scheduler=setup.checkpoint_scheduler,
+        placeholder_token=setup.placeholder_token,
+        placeholder_token_ids=setup.placeholder_token_ids,
+        train_dataloader=setup.train_dataloader,
+        val_dataloader=setup.val_dataloader,
+        output_dir=setup.output_dir,
+        seed=setup.seed,
+        sample_image_size=sample_image_size,
+        sample_batch_size=sample_batch_size,
+        sample_batches=sample_batches
+    )
+
+    if setup.accelerator.is_main_process:
+        setup.accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=setup.accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        checkpointer=checkpointer,
+        train_dataloader=setup.train_dataloader,
+        val_dataloader=setup.val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        sample_steps=sample_steps,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval
+    )
diff --git a/training/util.py b/training/util.py
index 0ec2032..cc4cdee 100644
--- a/training/util.py
+++ b/training/util.py
@@ -41,14 +41,16 @@ class AverageMeter:
 class CheckpointerBase:
     def __init__(
         self,
-        datamodule,
+        train_dataloader,
+        val_dataloader,
         output_dir: Path,
         sample_image_size: int,
         sample_batches: int,
         sample_batch_size: int,
         seed: Optional[int] = None
     ):
-        self.datamodule = datamodule
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed if seed is not None else torch.random.seed()
@@ -70,15 +72,16 @@ class CheckpointerBase:
     ):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        train_data = self.datamodule.train_dataloader
-        val_data = self.datamodule.val_dataloader
-
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
 
         grid_cols = min(self.sample_batch_size, 4)
        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
 
-        for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]:
+        for pool, data, gen in [
+            ("stable", self.val_dataloader, generator),
+            ("val", self.val_dataloader, None),
+            ("train", self.train_dataloader, None)
+        ]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
             file_path.parent.mkdir(parents=True, exist_ok=True)
--
cgit v1.2.3-54-g00ecf
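
For reference, here is a minimal sketch (not part of the patch) of how the pieces introduced above might be wired together in a training script: train_setup() from training/common.py builds the shared TrainingSetup state, and train_ti() from training/modules/ti.py runs textual-inversion training on top of it. The concrete argument values, model name, and file paths below are assumptions chosen for illustration; only the keyword names are taken from the signatures in this patch.

# Hypothetical driver script; paths, project name, and model name are placeholders.
from training.common import train_setup
from training.modules.ti import train_ti


def main():
    # Build the shared training state: accelerator, tokenizer, models,
    # noise/checkpoint schedulers, dataloaders, and the new placeholder tokens.
    setup = train_setup(
        output_dir="output/ti",                  # assumed output location
        project="my-token",                      # assumed project name
        pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1",  # assumed model
        learning_rate=1e-4,
        data_file="data/train.csv",              # assumed data file
        placeholder_token=["<my-token>"],
        initializer_token=["person"],
        num_vectors=4,
        train_batch_size=1,
        gradient_accumulation_steps=4,
        mixed_precision="fp16",
    )

    # Run textual-inversion training against that shared state.
    train_ti(
        setup,
        num_train_epochs=100,
        lr_scheduler="one_cycle",
        use_ema=True,
        sample_frequency=10,
        checkpoint_frequency=50,
    )


if __name__ == "__main__":
    main()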