From 3e7fbb7dce321435bbbb81361debfbc499bf9231 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 22:25:30 +0100
Subject: Reverted modularization mostly

---
 training/common.py             | 264 +++-----------------------------------
 training/modules/dreambooth.py |   0
 training/modules/lora.py       |   0
 training/modules/ti.py         | 284 -----------------------------------------
 training/optimization.py       |  53 ++++++++
 5 files changed, 70 insertions(+), 531 deletions(-)
 delete mode 100644 training/modules/dreambooth.py
 delete mode 100644 training/modules/lora.py
 delete mode 100644 training/modules/ti.py

(limited to 'training')

diff --git a/training/common.py b/training/common.py
index 73ce814..b6964a3 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,52 +1,24 @@
 import math
-from pathlib import Path
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
-import datetime
-import logging
+from typing import Callable, Any, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
-from accelerate.utils import LoggerType, set_seed
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from tqdm.auto import tqdm
-from slugify import slugify
 
-from data.csv import VlpnDataModule, VlpnDataItem
-from util import load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.optimization import get_one_cycle_schedule
 from training.util import AverageMeter, CheckpointerBase
 
 
-class TrainingSetup(NamedTuple):
-    accelerator: Accelerator
-    tokenizer: MultiCLIPTokenizer
-    text_encoder: CLIPTextModel
-    vae: AutoencoderKL
-    unet: UNet2DConditionModel
-    noise_scheduler: DDPMScheduler
-    checkpoint_scheduler: DPMSolverMultistepScheduler
-    optimizer_class: Callable
-    learning_rate: float
-    weight_dtype: torch.dtype
-    output_dir: Path
-    seed: int
-    train_dataloader: DataLoader
-    val_dataloader: DataLoader
-    placeholder_token: list[str]
-    placeholder_token_ids: list[list[int]]
-
-
 def noop(*args, **kwards):
     pass
 
@@ -59,57 +31,6 @@ def noop_on_log():
     return {}
 
 
-def get_scheduler(
-    id: str,
-    optimizer: torch.optim.Optimizer,
-    num_training_steps_per_epoch: int,
-    gradient_accumulation_steps: int,
-    min_lr: float = 0.04,
-    warmup_func: str = "cos",
-    annealing_func: str = "cos",
-    warmup_exp: int = 1,
-    annealing_exp: int = 1,
-    cycles: int = 1,
-    train_epochs: int = 100,
-    warmup_epochs: int = 10,
-):
-    num_training_steps_per_epoch = math.ceil(
-        num_training_steps_per_epoch / gradient_accumulation_steps
-    ) * gradient_accumulation_steps
-    num_training_steps = train_epochs * num_training_steps_per_epoch
-    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
-
-    if id == "one_cycle":
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=num_training_steps,
-            warmup=warmup_func,
-            annealing=annealing_func,
-            warmup_exp=warmup_exp,
-            annealing_exp=annealing_exp,
-            min_lr=min_lr,
-        )
-    elif id == "cosine_with_restarts":
-        if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
-
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_training_steps,
-            num_cycles=cycles,
-        )
-    else:
-        lr_scheduler = get_scheduler_(
-            id,
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_training_steps,
-        )
-
-    return lr_scheduler
-
-
 def generate_class_images(
     accelerator,
     text_encoder,
@@ -162,194 +83,43 @@ def generate_class_images(
     torch.cuda.empty_cache()
 
 
-def train_setup(
-    output_dir: str,
-    project: str,
-    pretrained_model_name_or_path: str,
-    learning_rate: float,
-    data_file: str,
-    gradient_accumulation_steps: int = 1,
-    mixed_precision: Literal["no", "fp16", "bf16"] = "no",
-    seed: Optional[int] = None,
-    vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
-    vector_dropout: float = 0.1,
-    gradient_checkpointing: bool = True,
-    embeddings_dir: Optional[str] = None,
-    placeholder_token: list[str] = [],
-    initializer_token: list[str] = [],
-    num_vectors: int = 1,
-    scale_lr: bool = False,
-    use_8bit_adam: bool = False,
-    train_batch_size: int = 1,
-    class_image_dir: Optional[str] = None,
-    num_class_images: int = 0,
-    resolution: int = 768,
-    num_buckets: int = 0,
-    progressive_buckets: bool = False,
-    bucket_step_size: int = 64,
-    bucket_max_pixels: Optional[int] = None,
-    tag_dropout: float = 0.1,
-    tag_shuffle: bool = True,
-    data_template: str = "template",
-    valid_set_size: Optional[int] = None,
-    valid_set_repeat: int = 1,
-    data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
-    sample_batch_size: int = 1,
-    sample_image_size: int = 768,
-    sample_steps: int = 20,
-) -> TrainingSetup:
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(output_dir).joinpath(slugify(project), now)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        mixed_precision=mixed_precision
-    )
-
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
-
-    seed = seed or (torch.random.seed() >> 32)
-    set_seed(seed)
-
-    # Load the tokenizer and add the placeholder token as a additional special token
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    tokenizer.set_use_vector_shuffle(vector_shuffle)
-    tokenizer.set_dropout(vector_dropout)
-
-    # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
 
-    if gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
     embeddings = patch_managed_embeddings(text_encoder)
 
-    if embeddings_dir is not None:
-        embeddings_dir = Path(embeddings_dir)
-        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-            raise ValueError("--embeddings_dir must point to an existing directory")
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
-        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    # Convert the initializer_token, placeholder_token to ids
+def add_placeholder_tokens(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    placeholder_tokens: list[str],
+    initializer_tokens: list[str],
+    num_vectors: Union[list[int], int]
+):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
-        for token in initializer_token
+        for token in initializer_tokens
     ]
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
 
-    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
     embeddings.resize(len(tokenizer))
 
-    for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(new_id, init_ids)
-
-    init_ratios = [
-        f"{len(init_ids)} / {len(new_id)}"
-        for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
-    ]
-
-    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
+    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    if scale_lr:
-        learning_rate = (
-            learning_rate * gradient_accumulation_steps *
-            train_batch_size * accelerator.num_processes
-        )
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    weight_dtype = torch.float32
-    if mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    datamodule = VlpnDataModule(
-        data_file=data_file,
-        batch_size=train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=class_image_dir,
-        num_class_images=num_class_images,
-        size=resolution,
-        num_buckets=num_buckets,
-        progressive_buckets=progressive_buckets,
-        bucket_step_size=bucket_step_size,
-        bucket_max_pixels=bucket_max_pixels,
-        dropout=tag_dropout,
-        shuffle=tag_shuffle,
-        template_key=data_template,
-        valid_set_size=valid_set_size,
-        valid_set_repeat=valid_set_repeat,
-        seed=seed,
-        filter=data_filter,
-        dtype=weight_dtype
-    )
-    datamodule.setup()
-
-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
-    train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
-
-    if num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            checkpoint_scheduler,
-            datamodule.data_train,
-            sample_batch_size,
-            sample_image_size,
-            sample_steps
-        )
-
-    return TrainingSetup(
-        accelerator=accelerator,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        noise_scheduler=noise_scheduler,
-        checkpoint_scheduler=checkpoint_scheduler,
-        optimizer_class=optimizer_class,
-        learning_rate=learning_rate,
-        output_dir=output_dir,
-        weight_dtype=weight_dtype,
-        seed=seed,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        placeholder_token=placeholder_token,
-        placeholder_token_ids=placeholder_token_ids
-    )
+    return placeholder_token_ids
 
 
 def loss_step(
diff --git a/training/modules/dreambooth.py b/training/modules/dreambooth.py
deleted file mode 100644
index e69de29..0000000
diff --git a/training/modules/lora.py b/training/modules/lora.py
deleted file mode 100644
index e69de29..0000000
diff --git a/training/modules/ti.py b/training/modules/ti.py
deleted file mode 100644
index 2db6f88..0000000
--- a/training/modules/ti.py
+++ /dev/null
@@ -1,284 +0,0 @@
-from typing import Literal
-from functools import partial
-from contextlib import contextmanager, nullcontext
-
-import torch
-
-from slugify import slugify
-
-from accelerate import Accelerator
-from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel
-
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.tokenizer import MultiCLIPTokenizer
-
-from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
-from training.util import EMAModel, CheckpointerBase
-
-
-class Checkpointer(CheckpointerBase):
-    def __init__(
-        self,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
-        ema_embeddings: EMAModel,
-        weight_dtype: torch.dtype,
-        scheduler,
-        placeholder_token,
-        placeholder_token_ids,
-        *args,
-        **kwargs
-    ):
-        super().__init__(*args, **kwargs)
-
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
-        self.ema_embeddings = ema_embeddings
-        self.scheduler = scheduler
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_ids = placeholder_token_ids
-
-    @torch.no_grad()
-    def checkpoint(self, step, postfix):
-        print("Saving checkpoint for step %d..." % step)
-
-        checkpoints_path = self.output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = nullcontext()
-        if self.ema_embeddings is not None:
-            ema_context = self.ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        with ema_context:
-            for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-                )
-
-        del text_encoder
-
-    @torch.no_grad()
-    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
-        ema_context = nullcontext()
-        if self.ema_embeddings is not None:
-            ema_context = self.ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-        with ema_context:
-            orig_dtype = text_encoder.dtype
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=self.unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
-
-            text_encoder.to(dtype=orig_dtype)
-
-        del text_encoder
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-
-def train_ti(
-    setup: TrainingSetup,
-    num_train_epochs: int = 100,
-    num_class_images: int = 0,
-    prior_loss_weight: float = 1.0,
-    use_ema: bool = False,
-    ema_inv_gamma: float = 1.0,
-    ema_power: float = 4/5,
-    ema_max_decay: float = .9999,
-    adam_beta1: float = 0.9,
-    adam_beta2: float = 0.999,
-    adam_weight_decay: float = 0,
-    adam_epsilon: float = 1e-08,
-    adam_amsgrad: bool = False,
-    lr_scheduler: Literal[
-        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle"
-    ] = "one_cycle",
-    lr_min_lr: float = 0.04,
-    lr_warmup_func: Literal["linear", "cos"] = "cos",
-    lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
-    lr_warmup_exp: int = 1,
-    lr_annealing_exp: int = 1,
-    lr_cycles: int = 1,
-    lr_warmup_epochs: int = 10,
-    emb_decay_target: float = 0.4,
-    emb_decay_factor: float = 1,
-    emb_decay_start: float = 1e-4,
-    sample_image_size: int = 768,
-    sample_batch_size: int = 1,
-    sample_batches: int = 1,
-    sample_frequency: int = 10,
-    sample_steps: int = 20,
-    checkpoint_frequency: int = 50,
-    global_step_offset: int = 0,
-):
-    if use_ema:
-        ema_embeddings = EMAModel(
-            setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=ema_inv_gamma,
-            power=ema_power,
-            max_value=ema_max_decay,
-        )
-    else:
-        ema_embeddings = None
-
-    setup.text_encoder.requires_grad_(True)
-    setup.text_encoder.text_model.encoder.requires_grad_(False)
-    setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
-    # Initialize the optimizer
-    optimizer = setup.optimizer_class(
-        setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=setup.learning_rate,
-        betas=(adam_beta1, adam_beta2),
-        weight_decay=adam_weight_decay,
-        eps=adam_epsilon,
-        amsgrad=adam_amsgrad,
-    )
-
-    lr_scheduler = get_scheduler(
-        lr_scheduler,
-        optimizer=optimizer,
-        min_lr=lr_min_lr,
-        warmup_func=lr_warmup_func,
-        annealing_func=lr_annealing_func,
-        warmup_exp=lr_warmup_exp,
-        annealing_exp=lr_annealing_exp,
-        cycles=lr_cycles,
-        train_epochs=num_train_epochs,
-        warmup_epochs=lr_warmup_epochs,
-        num_training_steps_per_epoch=len(setup.train_dataloader),
-        gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
-    )
-
-    text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
-        setup.text_encoder, optimizer, lr_scheduler
-    )
-
-    # Move vae and unet to device
-    setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
-    setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)
-
-    if use_ema:
-        ema_embeddings.to(setup.accelerator.device)
-
-    setup.unet.train()
-
-    @contextmanager
-    def on_train(epoch: int):
-        try:
-            setup.tokenizer.train()
-            yield
-        finally:
-            pass
-
-    @contextmanager
-    def on_eval():
-        try:
-            setup.tokenizer.eval()
-
-            ema_context = nullcontext()
-            if use_ema:
-                ema_context = ema_embeddings.apply_temporary(
-                    text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(
-            emb_decay_target,
-            min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
-        )
-
-        if use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    def on_log():
-        if use_ema:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
-
-    loss_step_ = partial(
-        loss_step,
-        setup.vae,
-        setup.noise_scheduler,
-        setup.unet,
-        text_encoder,
-        num_class_images != 0,
-        prior_loss_weight,
-        setup.seed,
-    )
-
-    checkpointer = Checkpointer(
-        accelerator=setup.accelerator,
-        vae=setup.vae,
-        unet=setup.unet,
-        tokenizer=setup.tokenizer,
-        text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
-        weight_dtype=setup.weight_dtype,
-        scheduler=setup.checkpoint_scheduler,
-        placeholder_token=setup.placeholder_token,
-        placeholder_token_ids=setup.placeholder_token_ids,
-        train_dataloader=setup.train_dataloader,
-        val_dataloader=setup.val_dataloader,
-        output_dir=setup.output_dir,
-        seed=setup.seed,
-        sample_image_size=sample_image_size,
-        sample_batch_size=sample_batch_size,
-        sample_batches=sample_batches
-    )
-
-    if setup.accelerator.is_main_process:
-        setup.accelerator.init_trackers("textual_inversion")
-
-    train_loop(
-        accelerator=setup.accelerator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model=text_encoder,
-        checkpointer=checkpointer,
-        train_dataloader=setup.train_dataloader,
-        val_dataloader=setup.val_dataloader,
-        loss_step=loss_step_,
-        sample_frequency=sample_frequency,
-        sample_steps=sample_steps,
-        checkpoint_frequency=checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        num_epochs=num_train_epochs,
-        on_log=on_log,
-        on_train=on_train,
-        on_after_optimize=on_after_optimize,
-        on_eval=on_eval
-    )
diff --git a/training/optimization.py b/training/optimization.py
index dd84f9c..5db7794 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,6 +5,8 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
+from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+
 
 class OneCyclePhase(NamedTuple):
     step_min: int
@@ -83,3 +85,54 @@ def get_one_cycle_schedule(
             return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_scheduler(
+    id: str,
+    optimizer: torch.optim.Optimizer,
+    num_training_steps_per_epoch: int,
+    gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
+):
+    num_training_steps_per_epoch = math.ceil(
+        num_training_steps_per_epoch / gradient_accumulation_steps
+    ) * gradient_accumulation_steps
+    num_training_steps = train_epochs * num_training_steps_per_epoch
+    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
+
+    if id == "one_cycle":
+        lr_scheduler = get_one_cycle_schedule(
+            optimizer=optimizer,
+            num_training_steps=num_training_steps,
+            warmup=warmup_func,
+            annealing=annealing_func,
+            warmup_exp=warmup_exp,
+            annealing_exp=annealing_exp,
+            min_lr=min_lr,
+        )
+    elif id == "cosine_with_restarts":
+        if cycles is None:
+            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+
+        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=cycles,
+        )
+    else:
+        lr_scheduler = get_scheduler_(
+            id,
+            optimizer=optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+        )
+
+    return lr_scheduler
-- 
cgit v1.2.3-54-g00ecf
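
For orientation (not part of the patch above): a minimal sketch of how the helpers this commit settles on — get_models() and add_placeholder_tokens() in training/common.py, plus the relocated get_scheduler() in training/optimization.py — might be wired together by a training script. The model path, token lists, learning rate, and dataloader are hypothetical placeholders, not values taken from the repository.

    # Illustrative sketch only; assumes the signatures introduced in this commit.
    import torch

    from training.common import get_models, add_placeholder_tokens
    from training.optimization import get_scheduler

    MODEL_NAME = "path/to/stable-diffusion"  # placeholder
    train_dataloader = ...  # DataLoader from the project's VlpnDataModule (not shown in this patch)

    # Load all pipeline parts plus the patched managed embeddings in one call.
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(MODEL_NAME)

    # Register new placeholder tokens and initialize them from existing tokens.
    placeholder_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=["<token>"],     # placeholder values
        initializer_tokens=["person"],
        num_vectors=1,
    )

    # Optimize only the temporary token embeddings, as the removed train_ti() did.
    optimizer = torch.optim.AdamW(
        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
        lr=1e-4,  # placeholder learning rate
    )

    # The scheduler factory now lives in training/optimization.py; "one_cycle"
    # selects the repository's own one-cycle schedule.
    lr_scheduler = get_scheduler(
        "one_cycle",
        optimizer=optimizer,
        num_training_steps_per_epoch=len(train_dataloader),
        gradient_accumulation_steps=1,
        train_epochs=100,
        warmup_epochs=10,
    )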