From 127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 18:59:26 +0100
Subject: More modularization

---
 training/common.py | 260 ++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 239 insertions(+), 21 deletions(-)

(limited to 'training/common.py')

diff --git a/training/common.py b/training/common.py
index 180396e..73ce814 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,46 +1,77 @@
 import math
+from pathlib import Path
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
+import datetime
+import logging
 
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
-from transformers import CLIPTokenizer, CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from accelerate.utils import LoggerType, set_seed
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 from tqdm.auto import tqdm
+from slugify import slugify
 
+from data.csv import VlpnDataModule, VlpnDataItem
+from util import load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
 from training.optimization import get_one_cycle_schedule
 from training.util import AverageMeter, CheckpointerBase
 
 
+class TrainingSetup(NamedTuple):
+    accelerator: Accelerator
+    tokenizer: MultiCLIPTokenizer
+    text_encoder: CLIPTextModel
+    vae: AutoencoderKL
+    unet: UNet2DConditionModel
+    noise_scheduler: DDPMScheduler
+    checkpoint_scheduler: DPMSolverMultistepScheduler
+    optimizer_class: Callable
+    learning_rate: float
+    weight_dtype: torch.dtype
+    output_dir: Path
+    seed: int
+    train_dataloader: DataLoader
+    val_dataloader: DataLoader
+    placeholder_token: list[str]
+    placeholder_token_ids: list[list[int]]
+
+
 def noop(*args, **kwards):
     pass
 
 
+def noop_ctx(*args, **kwards):
+    return nullcontext()
+
+
 def noop_on_log():
     return {}
 
 
 def get_scheduler(
     id: str,
-    min_lr: float,
-    lr: float,
-    warmup_func: str,
-    annealing_func: str,
-    warmup_exp: int,
-    annealing_exp: int,
-    cycles: int,
-    train_epochs: int,
-    warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
     num_training_steps_per_epoch: int,
     gradient_accumulation_steps: int,
+    min_lr: float = 0.04,
+    warmup_func: str = "cos",
+    annealing_func: str = "cos",
+    warmup_exp: int = 1,
+    annealing_exp: int = 1,
+    cycles: int = 1,
+    train_epochs: int = 100,
+    warmup_epochs: int = 10,
 ):
     num_training_steps_per_epoch = math.ceil(
         num_training_steps_per_epoch / gradient_accumulation_steps
@@ -49,8 +80,6 @@ def get_scheduler(
     num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
     if id == "one_cycle":
-        min_lr = 0.04 if min_lr is None else min_lr / lr
-
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=num_training_steps,
@@ -133,6 +162,196 @@ def generate_class_images(
     torch.cuda.empty_cache()
 
 
+def train_setup(
+    output_dir: str,
+    project: str,
+    pretrained_model_name_or_path: str,
+    learning_rate: float,
+    data_file: str,
+    gradient_accumulation_steps: int = 1,
+    mixed_precision: Literal["no", "fp16", "bf16"] = "no",
+    seed: Optional[int] = None,
+    vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
+    vector_dropout: float = 0.1,
+    gradient_checkpointing: bool = True,
+    embeddings_dir: Optional[str] = None,
+    placeholder_token: list[str] = [],
+    initializer_token: list[str] = [],
+    num_vectors: int = 1,
+    scale_lr: bool = False,
+    use_8bit_adam: bool = False,
+    train_batch_size: int = 1,
+    class_image_dir: Optional[str] = None,
+    num_class_images: int = 0,
+    resolution: int = 768,
+    num_buckets: int = 0,
+    progressive_buckets: bool = False,
+    bucket_step_size: int = 64,
+    bucket_max_pixels: Optional[int] = None,
+    tag_dropout: float = 0.1,
+    tag_shuffle: bool = True,
+    data_template: str = "template",
+    valid_set_size: Optional[int] = None,
+    valid_set_repeat: int = 1,
+    data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
+    sample_batch_size: int = 1,
+    sample_image_size: int = 768,
+    sample_steps: int = 20,
+) -> TrainingSetup:
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(output_dir).joinpath(slugify(project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{output_dir}",
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        mixed_precision=mixed_precision
+    )
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    seed = seed or (torch.random.seed() >> 32)
+    set_seed(seed)
+
+    # Load the tokenizer and add the placeholder token as a additional special token
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer.set_use_vector_shuffle(vector_shuffle)
+    tokenizer.set_dropout(vector_dropout)
+
+    # Load models and create wrapper for stable diffusion
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    if embeddings_dir is not None:
+        embeddings_dir = Path(embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_token
+    ]
+
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(new_id, init_ids)
+
+    init_ratios = [
+        f"{len(init_ids)} / {len(new_id)}"
+        for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
+    ]
+
+    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
+
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    if scale_lr:
+        learning_rate = (
+            learning_rate * gradient_accumulation_steps *
+            train_batch_size * accelerator.num_processes
+        )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    weight_dtype = torch.float32
+    if mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    datamodule = VlpnDataModule(
+        data_file=data_file,
+        batch_size=train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=class_image_dir,
+        num_class_images=num_class_images,
+        size=resolution,
+        num_buckets=num_buckets,
+        progressive_buckets=progressive_buckets,
+        bucket_step_size=bucket_step_size,
+        bucket_max_pixels=bucket_max_pixels,
+        dropout=tag_dropout,
+        shuffle=tag_shuffle,
+        template_key=data_template,
+        valid_set_size=valid_set_size,
+        valid_set_repeat=valid_set_repeat,
+        seed=seed,
+        filter=data_filter,
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    train_dataloader = datamodule.train_dataloader
+    val_dataloader = datamodule.val_dataloader
+
+    train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
+
+    if num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            sample_batch_size,
+            sample_image_size,
+            sample_steps
+        )
+
+    return TrainingSetup(
+        accelerator=accelerator,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        noise_scheduler=noise_scheduler,
+        checkpoint_scheduler=checkpoint_scheduler,
+        optimizer_class=optimizer_class,
+        learning_rate=learning_rate,
+        output_dir=output_dir,
+        weight_dtype=weight_dtype,
+        seed=seed,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        placeholder_token=placeholder_token,
+        placeholder_token_ids=placeholder_token_ids
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
@@ -221,15 +440,14 @@ def train_loop(
     sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
-    gradient_accumulation_steps: int = 1,
     num_epochs: int = 100,
     on_log: Callable[[], dict[str, Any]] = noop_on_log,
-    on_train: Callable[[], _GeneratorContextManager] = nullcontext,
-    on_before_optimize: Callable[[], None] = noop,
+    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
+    on_before_optimize: Callable[[int], None] = noop,
     on_after_optimize: Callable[[float], None] = noop,
-    on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
@@ -273,14 +491,14 @@ def train_loop(
 
         model.train()
 
-        with on_train():
+        with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(model):
                     loss, acc, bsz = loss_step(step, batch)
 
                     accelerator.backward(loss)
 
-                    on_before_optimize()
+                    on_before_optimize(epoch)
 
                     optimizer.step()
                     lr_scheduler.step()
-- 
cgit v1.2.3-54-g00ecf
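
A minimal sketch (not part of the commit) of how a training script might consume the new train_setup()/get_scheduler() split introduced above. The model path, data file, token names, and the choice of trained parameters are illustrative assumptions, not values from this repository; the full train_loop() signature is not visible in this patch, so the sketch stops at the scheduler, after which the prepared pieces would feed train_loop().

# Hypothetical driver script, assuming the API added in this commit.
from training.common import train_setup, get_scheduler

setup = train_setup(
    output_dir="output",                       # assumed output root
    project="example-project",                 # assumed project name
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # assumed checkpoint
    learning_rate=1e-4,
    data_file="data/train.csv",                # assumed VlpnDataModule CSV
    placeholder_token=["<new-token>"],         # assumed token to learn
    initializer_token=["person"],              # assumed initializer
    num_vectors=1,
)

# train_setup() freezes vae/unet/text_encoder, so the script picks what it
# trains. Using the text encoder's input embeddings here is only an
# illustration of the textual-inversion case, not the repo's exact wiring.
optimizer = setup.optimizer_class(
    setup.text_encoder.get_input_embeddings().parameters(),
    lr=setup.learning_rate,
)

# The scheduler knobs are now keyword defaults, so only the essentials are
# required; get_scheduler() divides the per-epoch step count by the
# accumulation steps internally.
lr_scheduler = get_scheduler(
    "one_cycle",
    optimizer=optimizer,
    num_training_steps_per_epoch=len(setup.train_dataloader),
    gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps,
    train_epochs=100,
)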
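
The train_loop() hooks also change shape in this commit: on_train and on_eval must now be callables returning a context manager (defaulting to the new noop_ctx instead of nullcontext), and on_train/on_before_optimize receive the current epoch. A sketch of conforming callbacks; the bodies are invented examples, not code from the repository.

from contextlib import contextmanager

@contextmanager
def on_train(epoch: int):
    # Invented example: a hook could toggle per-epoch behaviour here,
    # e.g. enabling tokenizer vector shuffling only after a warm-up epoch.
    print(f"entering training phase of epoch {epoch}")
    yield

def on_before_optimize(epoch: int):
    # Invented example: a script might clip gradients of the trained
    # parameters here, just before optimizer.step() runs.
    pass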