From 3e7fbb7dce321435bbbb81361debfbc499bf9231 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 22:25:30 +0100
Subject: Reverted modularization mostly

---
 train_ti.py | 467 +++++++++++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 386 insertions(+), 81 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 3a55f40..61195f6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,15 +1,29 @@
 import argparse
+import datetime
+import logging
+from functools import partial
+from pathlib import Path
+from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
 
+from accelerate import Accelerator
 from accelerate.logging import get_logger
-
-from util import load_config
-from data.csv import VlpnDataItem
-from training.common import train_setup
-from training.modules.ti import train_ti
-from training.util import save_args
+from accelerate.utils import LoggerType, set_seed
+from diffusers import AutoencoderKL, UNet2DConditionModel
+import matplotlib.pyplot as plt
+from transformers import CLIPTextModel
+from slugify import slugify
+
+from util import load_config, load_embeddings_from_dir
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from data.csv import VlpnDataModule, VlpnDataItem
+from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.optimization import get_scheduler
+from training.lr import LRFinder
+from training.util import CheckpointerBase, EMAModel, save_args
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 
 logger = get_logger(__name__)
@@ -52,13 +66,13 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--placeholder_tokens",
         type=str,
         nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
-        "--initializer_token",
+        "--initializer_tokens",
         type=str,
         nargs='*',
         help="A token to use as initializer word."
@@ -439,29 +453,29 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
-    if isinstance(args.placeholder_token, str):
-        args.placeholder_token = [args.placeholder_token]
+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]
 
-    if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
-    if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+    if isinstance(args.initializer_tokens, str):
+        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if len(args.initializer_token) == 0:
-        raise ValueError("You must specify --initializer_token")
+    if len(args.initializer_tokens) == 0:
+        raise ValueError("You must specify --initializer_tokens")
 
-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.initializer_tokens):
+        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
     if args.num_vectors is None:
         args.num_vectors = 1
 
     if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
 
-    if len(args.placeholder_token) != len(args.num_vectors):
-        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -475,13 +489,197 @@ def parse_args():
     return args
 
 
+class Checkpointer(CheckpointerBase):
+    def __init__(
+        self,
+        weight_dtype,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
+        scheduler,
+        placeholder_tokens,
+        placeholder_token_ids,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.weight_dtype = weight_dtype
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
+        self.scheduler = scheduler
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
+
+    @torch.no_grad()
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+        del text_encoder
+
+    @torch.no_grad()
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+
+            text_encoder.to(dtype=orig_dtype)
+
+        del text_encoder
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
 def main():
     args = parse_args()
 
-    def data_filter(item: VlpnDataItem):
+    global_step_offset = args.global_step
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
+    basepath.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
+
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
+
+    save_args(basepath, args)
+
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
+        args.pretrained_model_name_or_path)
+
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    if args.embeddings_dir is not None:
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    placeholder_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors
+    )
+
+    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+
+    if args.use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
+    else:
+        ema_embeddings = None
+
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+
+    if args.find_lr:
+        args.learning_rate = 1e-5
+
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    def keyword_filter(item: VlpnDataItem):
         cond1 = any(
             keyword in part
-            for keyword in args.placeholder_token
+            for keyword in args.placeholder_tokens
             for part in item.prompt
         )
         cond3 = args.collection is None or args.collection in item.collection
@@ -491,78 +689,185 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    setup = train_setup(
-        output_dir=args.output_dir,
-        project=args.project,
-        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
-        learning_rate=args.learning_rate,
+    datamodule = VlpnDataModule(
         data_file=args.train_data_file,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        mixed_precision=args.mixed_precision,
-        seed=args.seed,
-        vector_shuffle=args.vector_shuffle,
-        vector_dropout=args.vector_dropout,
-        gradient_checkpointing=args.gradient_checkpointing,
-        embeddings_dir=args.embeddings_dir,
-        placeholder_token=args.placeholder_token,
-        initializer_token=args.initializer_token,
-        num_vectors=args.num_vectors,
-        scale_lr=args.scale_lr,
-        use_8bit_adam=args.use_8bit_adam,
-        train_batch_size=args.train_batch_size,
-        class_image_dir=args.class_image_dir,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
-        resolution=args.resolution,
+        size=args.resolution,
         num_buckets=args.num_buckets,
         progressive_buckets=args.progressive_buckets,
         bucket_step_size=args.bucket_step_size,
         bucket_max_pixels=args.bucket_max_pixels,
-        tag_dropout=args.tag_dropout,
-        tag_shuffle=not args.no_tag_shuffle,
-        data_template=args.train_data_template,
+        dropout=args.tag_dropout,
+        shuffle=not args.no_tag_shuffle,
+        template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        data_filter=data_filter,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_steps=args.sample_steps,
+        num_workers=args.dataloader_num_workers,
+        seed=args.seed,
+        filter=keyword_filter,
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    train_dataloader = datamodule.train_dataloader
+    val_dataloader = datamodule.val_dataloader
+
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            sample_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )
+
+    if args.find_lr:
+        lr_scheduler = None
+    else:
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
+        )
+
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
-    save_args(setup.output_dir, args)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)
 
-    train_ti(
-        setup=setup,
-        num_train_epochs=args.num_train_epochs,
-        num_class_images=args.num_class_images,
-        prior_loss_weight=args.prior_loss_weight,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
-        adam_beta1=args.adam_beta1,
-        adam_beta2=args.adam_beta2,
-        adam_weight_decay=args.adam_weight_decay,
-        adam_epsilon=args.adam_epsilon,
-        adam_amsgrad=args.adam_amsgrad,
-        lr_scheduler=args.lr_scheduler,
-        lr_min_lr=args.lr_min_lr,
-        lr_warmup_func=args.lr_warmup_func,
-        lr_annealing_func=args.lr_annealing_func,
-        lr_warmup_exp=args.lr_warmup_exp,
-        lr_annealing_exp=args.lr_annealing_exp,
-        lr_cycles=args.lr_cycles,
-        lr_warmup_epochs=args.lr_warmup_epochs,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay_factor=args.emb_decay_factor,
-        emb_decay_start=args.emb_decay_start,
+    if args.use_ema:
+        ema_embeddings.to(accelerator.device)
+
+    if args.gradient_checkpointing:
+        unet.train()
+    else:
+        unet.eval()
+
+    @contextmanager
+    def on_train(epoch: int):
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval():
+        try:
+            tokenizer.eval()
+
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        text_encoder.text_model.embeddings.normalize(
+            args.emb_decay_target,
+            min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
+        )
+
+        if args.use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if args.use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        args.num_class_images != 0,
+        args.prior_loss_weight,
+        args.seed,
+    )
+
+    checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
+        scheduler=sample_scheduler,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
-        sample_frequency=args.sample_frequency,
-        sample_steps=args.sample_steps,
-        checkpoint_frequency=args.checkpoint_frequency,
-        global_step_offset=args.global_step,
-    )
+        seed=args.seed
+    )
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
+    if args.find_lr:
+        lr_finder = LRFinder(
+            accelerator=accelerator,
+            optimizer=optimizer,
+            model=text_encoder,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
+            on_train=on_train,
+            on_eval=on_eval,
+            on_after_optimize=on_after_optimize,
+        )
+        lr_finder.run(num_epochs=100, end_lr=1e3)
+
+        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.close()
+    else:
+        train_loop(
+            accelerator=accelerator,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            model=text_encoder,
+            checkpointer=checkpointer,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
+            sample_frequency=args.sample_frequency,
+            sample_steps=args.sample_steps,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=global_step_offset,
+            num_epochs=args.num_train_epochs,
+            on_log=on_log,
+            on_train=on_train,
+            on_after_optimize=on_after_optimize,
+            on_eval=on_eval
+        )
 
 
 if __name__ == "__main__":
--
cgit v1.2.3-54-g00ecf