From 3c6ccadd3c12c54a1fa2280bce505a2dd511958a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 07:27:45 +0100 Subject: Implemented extended Dreambooth training --- train_dreambooth.py | 484 +++++++++++++++++----------------------------------- 1 file changed, 155 insertions(+), 329 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 71bad7e..944256c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -1,10 +1,8 @@ import argparse -import itertools import datetime import logging from pathlib import Path from functools import partial -from contextlib import contextmanager, nullcontext import torch import torch.utils.checkpoint @@ -12,18 +10,15 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, UNet2DConditionModel -import matplotlib.pyplot as plt -from transformers import CLIPTextModel from slugify import slugify from util import load_config, load_embeddings_from_dir -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from data.csv import VlpnDataModule, VlpnDataItem +from data.csv import VlpnDataModule, keyword_filter +from training.functional import train, generate_class_images, add_placeholder_tokens, get_models +from training.strategy.ti import textual_inversion_strategy +from training.strategy.dreambooth import dreambooth_strategy from training.optimization import get_scheduler -from training.lr import LRFinder -from training.util import CheckpointerBase, EMAModel, save_args, generate_class_images, add_placeholder_tokens, get_models -from models.clip.tokenizer import MultiCLIPTokenizer +from training.util import save_args logger = get_logger(__name__) @@ -73,7 +68,7 @@ def parse_args(): help="A token to use as a placeholder for the concept.", ) parser.add_argument( - "--initializer_token", + "--initializer_tokens", type=str, nargs='*', default=[], @@ -151,7 +146,7 @@ def parse_args(): parser.add_argument( "--num_class_images", type=int, - default=1, + default=0, help="How many class images to generate." 
) parser.add_argument( @@ -437,23 +432,23 @@ def parse_args(): if isinstance(args.placeholder_tokens, str): args.placeholder_tokens = [args.placeholder_tokens] - if len(args.placeholder_tokens) == 0: - args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_token)] + if isinstance(args.initializer_tokens, str): + args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) - if isinstance(args.initializer_token, str): - args.initializer_token = [args.initializer_token] * len(args.placeholder_tokens) + if len(args.initializer_tokens) == 0: + raise ValueError("You must specify --initializer_tokens") - if len(args.initializer_token) == 0: - raise ValueError("You must specify --initializer_token") + if len(args.placeholder_tokens) == 0: + args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] - if len(args.placeholder_tokens) != len(args.initializer_token): - raise ValueError("--placeholder_tokens and --initializer_token must have the same number of items") + if len(args.placeholder_tokens) != len(args.initializer_tokens): + raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") if args.num_vectors is None: args.num_vectors = 1 if isinstance(args.num_vectors, int): - args.num_vectors = [args.num_vectors] * len(args.initializer_token) + args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) if len(args.placeholder_tokens) != len(args.num_vectors): raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") @@ -470,102 +465,9 @@ def parse_args(): return args -class Checkpointer(CheckpointerBase): - def __init__( - self, - weight_dtype: torch.dtype, - accelerator: Accelerator, - vae: AutoencoderKL, - unet: UNet2DConditionModel, - ema_unet: EMAModel, - tokenizer: MultiCLIPTokenizer, - text_encoder: CLIPTextModel, - scheduler, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - - self.weight_dtype = weight_dtype - self.accelerator = accelerator - self.vae = vae - self.unet = unet - self.ema_unet = ema_unet - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.scheduler = scheduler - - @torch.no_grad() - def save_model(self): - print("Saving model...") - - unet = self.accelerator.unwrap_model(self.unet) - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext() - - with ema_context: - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ) - pipeline.save_pretrained(self.output_dir.joinpath("model")) - - del unet - del text_encoder - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - @torch.no_grad() - def save_samples(self, step): - unet = self.accelerator.unwrap_model(self.unet) - text_encoder = self.accelerator.unwrap_model(self.text_encoder) - - ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext() - - with ema_context: - orig_unet_dtype = unet.dtype - orig_text_encoder_dtype = text_encoder.dtype - - unet.to(dtype=self.weight_dtype) - text_encoder.to(dtype=self.weight_dtype) - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=self.vae, - unet=unet, - tokenizer=self.tokenizer, - scheduler=self.scheduler, - ).to(self.accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) - - 
super().save_samples(pipeline, step) - - unet.to(dtype=orig_unet_dtype) - text_encoder.to(dtype=orig_text_encoder_dtype) - - del unet - del text_encoder - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - def main(): args = parse_args() - if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: - raise ValueError( - "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." - ) - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) output_dir.mkdir(parents=True, exist_ok=True) @@ -621,41 +523,12 @@ def main(): placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") - if args.use_ema: - ema_unet = EMAModel( - unet.parameters(), - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - ) - else: - ema_unet = None - - vae.requires_grad_(False) - - if args.train_text_encoder: - print(f"Training entire text encoder.") - - embeddings.persist() - text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) - else: - print(f"Training added text embeddings") - - text_encoder.text_model.encoder.requires_grad_(False) - text_encoder.text_model.final_layer_norm.requires_grad_(False) - text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) - text_encoder.text_model.embeddings.token_embedding.requires_grad_(False) - if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - if args.find_lr: - args.learning_rate = 1e-6 - - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -666,41 +539,30 @@ def main(): else: optimizer_class = torch.optim.AdamW - if args.train_text_encoder: - text_encoder_params_to_optimize = text_encoder.parameters() - else: - text_encoder_params_to_optimize = text_encoder.text_model.embeddings.temp_token_embedding.parameters() - - # Initialize the optimizer - optimizer = optimizer_class( - [ - { - 'params': unet.parameters(), - }, - { - 'params': text_encoder_params_to_optimize, - } - ], - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - amsgrad=args.adam_amsgrad, - ) - weight_dtype = torch.float32 if args.mixed_precision == "fp16": weight_dtype = torch.float16 elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - def keyword_filter(item: VlpnDataItem): - cond3 = args.collection is None or args.collection in item.collection - cond4 = args.exclude_collections is None or not any( - collection in item.collection - for collection in args.exclude_collections - ) - return cond3 and cond4 + trainer = partial( + train, + accelerator=accelerator, + unet=unet, + text_encoder=text_encoder, + vae=vae, + noise_scheduler=noise_scheduler, + dtype=weight_dtype, + seed=args.seed, + callbacks_fn=textual_inversion_strategy + ) + + # Initial TI + + print("Phase 1: Textual Inversion") + + cur_dir = output_dir.joinpath("1-ti") + cur_dir.mkdir(parents=True, exist_ok=True) datamodule = VlpnDataModule( data_file=args.train_data_file, @@ -709,182 +571,146 
@@ def main(): class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, - num_buckets=args.num_buckets, - progressive_buckets=args.progressive_buckets, - bucket_step_size=args.bucket_step_size, - bucket_max_pixels=args.bucket_max_pixels, - dropout=args.tag_dropout, shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, seed=args.seed, - filter=keyword_filter, + filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), dtype=weight_dtype ) datamodule.setup() - train_dataloader = datamodule.train_dataloader - val_dataloader = datamodule.val_dataloader - - if args.num_class_images != 0: - generate_class_images( - accelerator, - text_encoder, - vae, - unet, - tokenizer, - sample_scheduler, - datamodule.data_train, - args.sample_batch_size, - args.sample_image_size, - args.sample_steps - ) - - if args.find_lr: - lr_scheduler = None - else: - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - min_lr=args.lr_min_lr, - warmup_func=args.lr_warmup_func, - annealing_func=args.lr_annealing_func, - warmup_exp=args.lr_warmup_exp, - annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, - train_epochs=args.num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, - num_training_steps_per_epoch=len(train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps - ) + optimizer = optimizer_class( + text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + lr=2e-1, + weight_decay=0.0, + ) - unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + lr_scheduler = get_scheduler( + "linear", + optimizer=optimizer, + num_training_steps_per_epoch=len(datamodule.train_dataloader), + gradient_accumulation_steps=args.gradient_accumulation_steps, + train_epochs=30, + warmup_epochs=10, + ) + + trainer( + project="textual_inversion", + train_dataloader=datamodule.train_dataloader, + val_dataloader=datamodule.val_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + num_train_epochs=30, + sample_frequency=5, + checkpoint_frequency=9999999, + with_prior_preservation=args.num_class_images != 0, + prior_loss_weight=args.prior_loss_weight, + # -- + tokenizer=tokenizer, + sample_scheduler=sample_scheduler, + output_dir=cur_dir, + placeholder_tokens=args.placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + learning_rate=2e-1, + gradient_checkpointing=args.gradient_checkpointing, + use_emb_decay=True, + sample_batch_size=args.sample_batch_size, + sample_num_batches=args.sample_batches, + sample_num_steps=args.sample_steps, + sample_image_size=args.sample_image_size, ) - vae.to(accelerator.device, dtype=weight_dtype) + # Dreambooth - if args.use_ema: - ema_unet.to(accelerator.device) + print("Phase 2: Dreambooth") - @contextmanager - def on_train(epoch: int): - try: - tokenizer.train() + cur_dir = output_dir.joinpath("2db") + cur_dir.mkdir(parents=True, exist_ok=True) - if epoch < args.train_text_encoder_epochs: - text_encoder.train() - elif epoch == args.train_text_encoder_epochs: - text_encoder.requires_grad_(False) + args.seed = (args.seed + 28635) >> 32 - yield - finally: - pass + datamodule = VlpnDataModule( + data_file=args.train_data_file, + batch_size=args.train_batch_size, + tokenizer=tokenizer, + class_subdir=args.class_image_dir, + 
num_class_images=args.num_class_images, + size=args.resolution, + num_buckets=args.num_buckets, + progressive_buckets=args.progressive_buckets, + bucket_step_size=args.bucket_step_size, + bucket_max_pixels=args.bucket_max_pixels, + dropout=args.tag_dropout, + shuffle=not args.no_tag_shuffle, + template_key=args.train_data_template, + valid_set_size=args.valid_set_size, + valid_set_repeat=args.valid_set_repeat, + seed=args.seed, + filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), + dtype=weight_dtype + ) + datamodule.setup() - @contextmanager - def on_eval(): - try: - tokenizer.eval() - text_encoder.eval() - - ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext() - - with ema_context: - yield - finally: - pass - - def on_before_optimize(epoch: int): - if accelerator.sync_gradients: - params_to_clip = [unet.parameters()] - if args.train_text_encoder and epoch < args.train_text_encoder_epochs: - params_to_clip.append(text_encoder.parameters()) - accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), args.max_grad_norm) - - @torch.no_grad() - def on_after_optimize(lr: float): - if not args.train_text_encoder: - text_encoder.text_model.embeddings.normalize( - args.decay_target, - min(1.0, max(0.0, args.decay_factor * ((lr - args.decay_start) / (args.learning_rate - args.decay_start)))) - ) - - def on_log(): - if args.use_ema: - return {"ema_decay": ema_unet.decay} - return {} - - loss_step_ = partial( - loss_step, - vae, - noise_scheduler, - unet, - text_encoder, - args.prior_loss_weight, - args.seed, - ) - - checkpointer = Checkpointer( - weight_dtype=weight_dtype, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - accelerator=accelerator, - vae=vae, - unet=unet, - ema_unet=ema_unet, + optimizer = optimizer_class( + [ + { + 'params': unet.parameters(), + }, + { + 'params': text_encoder.parameters(), + } + ], + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + amsgrad=args.adam_amsgrad, + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps_per_epoch=len(datamodule.train_dataloader), + gradient_accumulation_steps=args.gradient_accumulation_steps, + min_lr=args.lr_min_lr, + warmup_func=args.lr_warmup_func, + annealing_func=args.lr_annealing_func, + warmup_exp=args.lr_warmup_exp, + annealing_exp=args.lr_annealing_exp, + cycles=args.lr_cycles, + train_epochs=args.num_train_epochs, + warmup_epochs=args.lr_warmup_epochs, + ) + + trainer( + project="dreambooth", + train_dataloader=datamodule.train_dataloader, + val_dataloader=datamodule.val_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + num_train_epochs=args.num_train_epochs, + sample_frequency=args.sample_frequency, + checkpoint_frequency=args.checkpoint_frequency, + with_prior_preservation=args.num_class_images != 0, + prior_loss_weight=args.prior_loss_weight, + # -- tokenizer=tokenizer, - text_encoder=text_encoder, - scheduler=sample_scheduler, - placeholder_tokens=args.placeholder_tokens, - placeholder_token_ids=placeholder_token_ids, - output_dir=output_dir, - sample_steps=args.sample_steps, - sample_image_size=args.sample_image_size, + sample_scheduler=sample_scheduler, + output_dir=cur_dir, + gradient_checkpointing=args.gradient_checkpointing, + train_text_encoder_epochs=args.train_text_encoder_epochs, + max_grad_norm=args.max_grad_norm, + use_ema=args.use_ema, + 
ema_inv_gamma=args.ema_inv_gamma, + ema_power=args.ema_power, + ema_max_decay=args.ema_max_decay, sample_batch_size=args.sample_batch_size, - sample_batches=args.sample_batches, - seed=args.seed - ) - - if accelerator.is_main_process: - accelerator.init_trackers("dreambooth", config=config) - - if args.find_lr: - lr_finder = LRFinder( - accelerator=accelerator, - optimizer=optimizer, - model=unet, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - loss_step=loss_step_, - on_train=on_train, - on_eval=on_eval, - on_before_optimize=on_before_optimize, - on_after_optimize=on_after_optimize, - ) - lr_finder.run(num_epochs=100, end_lr=1e2) - - plt.savefig(output_dir.joinpath("lr.png"), dpi=300) - plt.close() - else: - train_loop( - accelerator=accelerator, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model=unet, - checkpointer=checkpointer, - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, - loss_step=loss_step_, - sample_frequency=args.sample_frequency, - checkpoint_frequency=args.checkpoint_frequency, - global_step_offset=0, - num_epochs=args.num_train_epochs, - on_log=on_log, - on_train=on_train, - on_after_optimize=on_after_optimize, - on_eval=on_eval - ) + sample_num_batches=args.sample_batches, + sample_num_steps=args.sample_steps, + sample_image_size=args.sample_image_size, + ) if __name__ == "__main__": -- cgit v1.2.3-54-g00ecf
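The patch restructures training into two sequential phases that reuse one generic train() from training.functional: the pieces both phases share (accelerator, unet, text_encoder, vae, noise_scheduler, dtype, seed, strategy callbacks) are bound once with functools.partial, and each phase then supplies its own dataloaders, optimizer, scheduler and output directory. Below is a minimal, self-contained sketch of that calling pattern only; the toy train() stand-in, its signature, and the Dreambooth learning rate and epoch values are illustrative assumptions, not the repository's API. Note that the hunks above bind callbacks_fn=textual_inversion_strategy into the shared partial for both phases; the sketch passes a strategy per phase purely to make the two-phase idea explicit.

from dataclasses import dataclass
from functools import partial
from typing import Callable


@dataclass
class PhaseResult:
    project: str
    epochs: int
    learning_rate: float


def train(*, project: str, strategy: Callable[[], str], learning_rate: float,
          num_train_epochs: int, seed: int, output_dir: str) -> PhaseResult:
    # Toy stand-in for training.functional.train: it only records which phase
    # ran with which settings instead of running a diffusion training loop.
    print(f"[{project}] strategy={strategy()} lr={learning_rate} "
          f"epochs={num_train_epochs} seed={seed} out={output_dir}")
    return PhaseResult(project, num_train_epochs, learning_rate)


def textual_inversion_strategy() -> str:
    return "ti"


def dreambooth_strategy() -> str:
    return "dreambooth"


# Bind everything both phases share once, as the patch does for the
# accelerator, models, noise scheduler, dtype and seed.
trainer = partial(train, seed=42)

# Phase 1: textual inversion of the new embeddings at a fixed high LR,
# mirroring the hard-coded lr=2e-1 / 30 epochs in the hunk above.
trainer(project="textual_inversion", strategy=textual_inversion_strategy,
        learning_rate=2e-1, num_train_epochs=30, output_dir="out/1-ti")

# Phase 2: full Dreambooth fine-tuning with user-supplied hyperparameters
# (the values here are placeholders for args.learning_rate and
# args.num_train_epochs).
trainer(project="dreambooth", strategy=dreambooth_strategy,
        learning_rate=2e-6, num_train_epochs=100, output_dir="out/2db")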
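The revised parse_args() also inverts the token defaulting: --initializer_tokens is now mandatory, missing placeholder tokens are generated as <*0>, <*1>, ..., and --num_vectors is broadcast to one entry per token pair. The following standalone snippet reproduces just that normalization logic from the hunk above so the defaulting behaviour can be checked in isolation; the example token values are made up.

def normalize_token_args(placeholder_tokens, initializer_tokens, num_vectors):
    # Mirrors the validation order introduced in parse_args() above.
    if isinstance(placeholder_tokens, str):
        placeholder_tokens = [placeholder_tokens]

    if isinstance(initializer_tokens, str):
        initializer_tokens = [initializer_tokens] * len(placeholder_tokens)

    if len(initializer_tokens) == 0:
        raise ValueError("You must specify --initializer_tokens")

    if len(placeholder_tokens) == 0:
        placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]

    if len(placeholder_tokens) != len(initializer_tokens):
        raise ValueError(
            "--placeholder_tokens and --initializer_tokens must have the same number of items")

    if num_vectors is None:
        num_vectors = 1

    if isinstance(num_vectors, int):
        num_vectors = [num_vectors] * len(initializer_tokens)

    if len(placeholder_tokens) != len(num_vectors):
        raise ValueError(
            "--placeholder_tokens and --num_vectors must have the same number of items")

    return placeholder_tokens, initializer_tokens, num_vectors


# Two initializer tokens and no explicit placeholders yield generated names:
print(normalize_token_args([], ["sks", "person"], None))
# -> (['<*0>', '<*1>'], ['sks', 'person'], [1, 1])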
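Dataset filtering likewise moves from a local closure to the shared data.csv.keyword_filter, bound to the CLI arguments with functools.partial. The stand-in below is inferred only from the call site partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections) and from the collection checks of the removed inline closure; the real function's signature, token matching, and item type may differ.

from dataclasses import dataclass, field
from functools import partial
from typing import Optional


@dataclass
class Item:
    # Hypothetical stand-in for the data module's item type.
    prompt: str
    collection: list = field(default_factory=list)


def keyword_filter(placeholder_tokens: Optional[list],
                   collection: Optional[str],
                   exclude_collections: Optional[list],
                   item: Item) -> bool:
    # Keep items that mention one of the placeholder tokens (assumed behaviour),
    has_token = placeholder_tokens is None or any(
        token in item.prompt for token in placeholder_tokens)
    # belong to the requested collection (the removed closure's cond3),
    in_collection = collection is None or collection in item.collection
    # and are not in any excluded collection (the removed closure's cond4).
    not_excluded = exclude_collections is None or not any(
        c in item.collection for c in exclude_collections)
    return has_token and in_collection and not_excluded


# Bind the CLI-derived values once; the data module then calls the predicate
# per item, mirroring filter=partial(keyword_filter, ...) above.
item_filter = partial(keyword_filter, ["<token>"], None, ["junk"])
print(item_filter(Item("a photo of <token>", ["main"])))  # True
print(item_filter(Item("a photo of <token>", ["junk"])))  # False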