From 6ecfdb73d150c5a596722ec3234e53f4796fbc78 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 13 Jan 2023 23:09:14 +0100
Subject: Unified training script structure

---
 train_dreambooth.py | 205 ++++++++++++++++++++--------------------------------
 train_ti.py         |   9 ++-
 2 files changed, 84 insertions(+), 130 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2145e2b..a1802a0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -12,20 +12,18 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
 import matplotlib.pyplot as plt
-from diffusers.training_utils import EMAModel
 from transformers import CLIPTextModel
 from slugify import slugify

 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images
+from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import CheckpointerBase, save_args
-from models.clip.embeddings import patch_managed_embeddings
+from training.util import CheckpointerBase, EMAModel, save_args
 from models.clip.tokenizer import MultiCLIPTokenizer

 logger = get_logger(__name__)
@@ -69,7 +67,7 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--placeholder_tokens",
         type=str,
         nargs='*',
         default=[],
@@ -446,20 +444,20 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

-    if isinstance(args.placeholder_token, str):
-        args.placeholder_token = [args.placeholder_token]
+    if isinstance(args.placeholder_tokens, str):
+        args.placeholder_tokens = [args.placeholder_tokens]

-    if len(args.placeholder_token) == 0:
-        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+    if len(args.placeholder_tokens) == 0:
+        args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_token)]

     if isinstance(args.initializer_token, str):
-        args.initializer_token = [args.initializer_token] * len(args.placeholder_token)
+        args.initializer_token = [args.initializer_token] * len(args.placeholder_tokens)

     if len(args.initializer_token) == 0:
         raise ValueError("You must specify --initializer_token")

-    if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.initializer_token):
+        raise ValueError("--placeholder_tokens and --initializer_token must have the same number of items")

     if args.num_vectors is None:
         args.num_vectors = 1
@@ -467,8 +465,8 @@ def parse_args():
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.initializer_token)

-    if len(args.placeholder_token) != len(args.num_vectors):
-        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
+    if len(args.placeholder_tokens) != len(args.num_vectors):
+        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -485,33 +483,18 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
-        weight_dtype,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        ema_unet,
-        tokenizer,
-        text_encoder,
+        weight_dtype: torch.dtype,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        ema_unet: EMAModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
         scheduler,
-        output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed,
+        *args,
+        **kwargs
     ):
-        super().__init__(
-            datamodule=datamodule,
-            output_dir=output_dir,
-            placeholder_token=placeholder_token,
-            placeholder_token_id=placeholder_token_id,
-            sample_image_size=sample_image_size,
-            seed=seed or torch.random.seed(),
-            sample_batches=sample_batches,
-            sample_batch_size=sample_batch_size
-        )
+        super().__init__(*args, **kwargs)

         self.weight_dtype = weight_dtype
         self.accelerator = accelerator
@@ -606,28 +589,19 @@ def main():

     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

-    args.seed = args.seed or (torch.random.seed() >> 32)
+    if args.seed is None:
+        args.seed = torch.random.seed() >> 32
+
     set_seed(args.seed)

     save_args(basepath, args)

-    # Load the tokenizer and add the placeholder token as a additional special token
-    if args.tokenizer_name:
-        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
-    elif args.pretrained_model_name_or_path:
-        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
+        args.pretrained_model_name_or_path)
+
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)

-    # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder='scheduler')
-
-    ema_unet = None
-
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -636,16 +610,6 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()

-    if args.use_ema:
-        ema_unet = EMAModel(
-            unet.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
@@ -654,24 +618,26 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

-    if len(args.placeholder_token) != 0:
-        # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = [
-            tokenizer.encode(token, add_special_tokens=False)
-            for token in args.initializer_token
-        ]
-
-        new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
-        embeddings.resize(len(tokenizer))
-
-        for (new_id, init_ids) in zip(new_ids, initializer_token_ids):
-            embeddings.add_embed(new_id, init_ids)
+    placeholder_token_ids = add_placeholder_tokens(
+        tokenizer=tokenizer,
+        embeddings=embeddings,
+        placeholder_tokens=args.placeholder_tokens,
+        initializer_tokens=args.initializer_tokens,
+        num_vectors=args.num_vectors
+    )

-        init_ratios = [f"{len(init_ids)} / {len(new_id)}" for new_id, init_ids in zip(new_ids, initializer_token_ids)]
+    if len(placeholder_token_ids) != 0:
+        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")

-        print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
+    if args.use_ema:
+        ema_unet = EMAModel(
+            unet.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
     else:
-        placeholder_token_id = []
+        ema_unet = None

     vae.requires_grad_(False)

@@ -765,8 +731,6 @@ def main():
         filter=keyword_filter,
         dtype=weight_dtype
     )
-
-    datamodule.prepare_data()
     datamodule.setup()

     train_dataloader = datamodule.train_dataloader
@@ -779,7 +743,7 @@ def main():
             vae,
             unet,
             tokenizer,
-            checkpoint_scheduler,
+            sample_scheduler,
             datamodule.data_train,
             args.sample_batch_size,
             args.sample_image_size,
@@ -808,12 +772,8 @@ def main():
         unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    # Move text_encoder and vae to device
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Keep text_encoder and vae in eval mode as we don't train these
-    vae.eval()
-
     if args.use_ema:
         ema_unet.to(accelerator.device)

@@ -877,17 +837,18 @@ def main():

     checkpointer = Checkpointer(
         weight_dtype=weight_dtype,
-        datamodule=datamodule,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
         accelerator=accelerator,
         vae=vae,
         unet=unet,
         ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        scheduler=checkpoint_scheduler,
+        scheduler=sample_scheduler,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
         output_dir=basepath,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -895,23 +856,16 @@ def main():
     )

     if accelerator.is_main_process:
-        config = vars(args).copy()
-        config["initializer_token"] = " ".join(config["initializer_token"])
-        config["placeholder_token"] = " ".join(config["placeholder_token"])
-        if config["collection"] is not None:
-            config["collection"] = " ".join(config["collection"])
-        if config["exclude_collections"] is not None:
-            config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("dreambooth", config=config)

     if args.find_lr:
         lr_finder = LRFinder(
-            accelerator,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            val_dataloader,
-            loss_step_,
+            accelerator=accelerator,
+            optimizer=optimizer,
+            model=unet,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             on_train=on_train,
             on_eval=on_eval,
             on_before_optimize=on_before_optimize,
@@ -921,29 +875,26 @@ def main():

         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
-
-        return
-
-    train_loop(
-        accelerator=accelerator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model=unet,
-        checkpointer=checkpointer,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        loss_step=loss_step_,
-        sample_frequency=args.sample_frequency,
-        sample_steps=args.sample_steps,
-        checkpoint_frequency=args.checkpoint_frequency,
-        global_step_offset=0,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        num_epochs=args.num_train_epochs,
-        on_log=on_log,
-        on_train=on_train,
-        on_after_optimize=on_after_optimize,
-        on_eval=on_eval
-    )
+    else:
+        train_loop(
+            accelerator=accelerator,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            model=unet,
+            checkpointer=checkpointer,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
+            sample_frequency=args.sample_frequency,
+            sample_steps=args.sample_steps,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=0,
+            num_epochs=args.num_train_epochs,
+            on_log=on_log,
+            on_train=on_train,
+            on_after_optimize=on_after_optimize,
+            on_eval=on_eval
+        )


 if __name__ == "__main__":
diff --git a/train_ti.py b/train_ti.py
index 61195f6..d2ca7eb 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -492,7 +492,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
-        weight_dtype,
+        weight_dtype: torch.dtype,
         accelerator: Accelerator,
         vae: AutoencoderKL,
         unet: UNet2DConditionModel,
@@ -587,7 +587,9 @@ def main():

     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

-    args.seed = args.seed or (torch.random.seed() >> 32)
+    if args.seed is None:
+        args.seed = torch.random.seed() >> 32
+
     set_seed(args.seed)

     save_args(basepath, args)
@@ -622,7 +624,8 @@ def main():
         num_vectors=args.num_vectors
     )

-    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+    if len(placeholder_token_ids) != 0:
+        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")

     if args.use_ema:
         ema_embeddings = EMAModel(
-- 
cgit v1.2.3-70-g09d2
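
Note: the commit above imports two new helpers, get_models and add_placeholder_tokens, from training.common without showing their definitions. As a reading aid only (not part of the patch), here is a minimal sketch of what they could look like, reconstructed from the call sites and from the inline code the commit removes; the actual implementation in training/common.py may differ, and everything beyond what the diff itself shows should be treated as an assumption.

from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel

from models.clip.embeddings import patch_managed_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer


def get_models(pretrained_model_name_or_path: str):
    # Load every sub-model from its own subfolder of the checkpoint, mirroring
    # the from_pretrained() calls removed from train_dreambooth.py above.
    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder='scheduler')

    # Swap in the managed embedding layer so placeholder vectors can be added later.
    embeddings = patch_managed_embeddings(text_encoder)

    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings


def add_placeholder_tokens(tokenizer, embeddings, placeholder_tokens, initializer_tokens, num_vectors):
    # Encode each initializer token, register the placeholder tokens, and copy the
    # initializer embeddings into the newly created vectors; these are the same
    # steps the removed inline block in train_dreambooth.py performed.
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)
        for token in initializer_tokens
    ]
    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
    embeddings.resize(len(tokenizer))

    for placeholder_token_id, initializer_token_id in zip(placeholder_token_ids, initializer_token_ids):
        embeddings.add_embed(placeholder_token_id, initializer_token_id)

    return placeholder_token_ids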