From 5821523a524190490a287c5e2aacb6e72cc3a4cf Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 17 Jan 2023 07:20:45 +0100 Subject: Update --- train_dreambooth.py | 5 +- train_ti.py | 113 +++++++++++++++++++++++----------------- training/functional.py | 19 +++++-- training/strategy/dreambooth.py | 10 +++- training/strategy/ti.py | 19 ++++--- training/util.py | 11 ++-- 6 files changed, 104 insertions(+), 73 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index d722e68..48bdcf8 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -14,8 +14,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter -from training.functional import train, generate_class_images, add_placeholder_tokens, get_models -from training.strategy.ti import textual_inversion_strategy +from training.functional import train, get_models from training.strategy.dreambooth import dreambooth_strategy from training.optimization import get_scheduler from training.util import save_args @@ -610,7 +609,7 @@ def main(): ) trainer( - callbacks_fn=dreambooth_strategy, + strategy=dreambooth_strategy, project="dreambooth", train_dataloader=datamodule.train_dataloader, val_dataloader=datamodule.val_dataloader, diff --git a/train_ti.py b/train_ti.py index e7aeb23..0891c49 100644 --- a/train_ti.py +++ b/train_ti.py @@ -14,7 +14,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from data.csv import VlpnDataModule, keyword_filter -from training.functional import train, generate_class_images, add_placeholder_tokens, get_models +from training.functional import train, add_placeholder_tokens, get_models from training.strategy.ti import textual_inversion_strategy from training.optimization import get_scheduler from training.util import save_args @@ -78,6 +78,10 @@ def parse_args(): nargs='*', help="Number of vectors per embedding." ) + parser.add_argument( + "--simultaneous", + action="store_true", + ) parser.add_argument( "--num_class_images", type=int, @@ -474,11 +478,12 @@ def parse_args(): if len(args.placeholder_tokens) != len(args.num_vectors): raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") - if isinstance(args.train_data_template, str): - args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) + if not args.simultaneous: + if isinstance(args.train_data_template, str): + args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) - if len(args.placeholder_tokens) != len(args.train_data_template): - raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") + if len(args.placeholder_tokens) != len(args.train_data_template): + raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") if isinstance(args.collection, str): args.collection = [args.collection] @@ -560,6 +565,8 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + checkpoint_output_dir = output_dir.joinpath("checkpoints") + trainer = partial( train, accelerator=accelerator, @@ -569,30 +576,50 @@ def main(): noise_scheduler=noise_scheduler, dtype=weight_dtype, seed=args.seed, - callbacks_fn=textual_inversion_strategy - ) - - checkpoint_output_dir = output_dir.joinpath("checkpoints") - - for i, placeholder_token, initializer_token, num_vectors, data_template in zip( - range(len(args.placeholder_tokens)), - args.placeholder_tokens, - args.initializer_tokens, - args.num_vectors, - args.train_data_template - ): - sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}") + with_prior_preservation=args.num_class_images != 0, + prior_loss_weight=args.prior_loss_weight, + strategy=textual_inversion_strategy, + num_train_epochs=args.num_train_epochs, + sample_frequency=args.sample_frequency, + checkpoint_frequency=args.checkpoint_frequency, + global_step_offset=global_step_offset, + # -- + tokenizer=tokenizer, + sample_scheduler=sample_scheduler, + checkpoint_output_dir=checkpoint_output_dir, + learning_rate=args.learning_rate, + gradient_checkpointing=args.gradient_checkpointing, + use_emb_decay=args.use_emb_decay, + emb_decay_target=args.emb_decay_target, + emb_decay_factor=args.emb_decay_factor, + emb_decay_start=args.emb_decay_start, + use_ema=args.use_ema, + ema_inv_gamma=args.ema_inv_gamma, + ema_power=args.ema_power, + ema_max_decay=args.ema_max_decay, + sample_batch_size=args.sample_batch_size, + sample_num_batches=args.sample_batches, + sample_num_steps=args.sample_steps, + sample_image_size=args.sample_image_size, + ) + + def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): + if len(placeholder_tokens) == 1: + sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}") + else: + sample_output_dir = output_dir.joinpath("samples") placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, embeddings=embeddings, - placeholder_tokens=[placeholder_token], - initializer_tokens=[initializer_token], - num_vectors=[num_vectors] + placeholder_tokens=placeholder_tokens, + initializer_tokens=initializer_tokens, + num_vectors=num_vectors ) - print( - f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") + stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) + + print(f"{i + 1}: {stats})") datamodule = VlpnDataModule( data_file=args.train_data_file, @@ -612,7 +639,7 @@ def main(): train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, seed=args.seed, - filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), + filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections), dtype=weight_dtype ) datamodule.setup() @@ -647,36 +674,24 @@ def main(): val_dataloader=datamodule.val_dataloader, optimizer=optimizer, lr_scheduler=lr_scheduler, - num_train_epochs=args.num_train_epochs, - sample_frequency=args.sample_frequency, - checkpoint_frequency=args.checkpoint_frequency, - global_step_offset=global_step_offset, - with_prior_preservation=args.num_class_images != 0, - prior_loss_weight=args.prior_loss_weight, # -- - tokenizer=tokenizer, - sample_scheduler=sample_scheduler, sample_output_dir=sample_output_dir, - checkpoint_output_dir=checkpoint_output_dir, - placeholder_tokens=[placeholder_token], + placeholder_tokens=placeholder_tokens, placeholder_token_ids=placeholder_token_ids, - learning_rate=args.learning_rate, - gradient_checkpointing=args.gradient_checkpointing, - use_emb_decay=args.use_emb_decay, - emb_decay_target=args.emb_decay_target, - emb_decay_factor=args.emb_decay_factor, - emb_decay_start=args.emb_decay_start, - use_ema=args.use_ema, - ema_inv_gamma=args.ema_inv_gamma, - ema_power=args.ema_power, - ema_max_decay=args.ema_max_decay, - sample_batch_size=args.sample_batch_size, - sample_num_batches=args.sample_batches, - sample_num_steps=args.sample_steps, - sample_image_size=args.sample_image_size, ) - embeddings.persist() + if args.simultaneous: + run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) + else: + for i, placeholder_token, initializer_token, num_vectors, data_template in zip( + range(len(args.placeholder_tokens)), + args.placeholder_tokens, + args.initializer_tokens, + args.num_vectors, + args.train_data_template + ): + run(i, [placeholder_token], [initializer_token], [num_vectors], data_template) + embeddings.persist() if __name__ == "__main__": diff --git a/training/functional.py b/training/functional.py index 3d27380..7a3e821 100644 --- a/training/functional.py +++ b/training/functional.py @@ -39,11 +39,18 @@ class TrainingCallbacks(): on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) on_before_optimize: Callable[[int], None] = const() on_after_optimize: Callable[[float], None] = const() + on_after_epoch: Callable[[float], None] = const() on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) on_sample: Callable[[int], None] = const() on_checkpoint: Callable[[int, str], None] = const() +@dataclass +class TrainingStrategy(): + callbacks: Callable[..., TrainingCallbacks] + prepare_unet: bool = False + + def make_grid(images, rows, cols): w, h = images[0].size grid = Image.new('RGB', size=(cols*w, rows*h)) @@ -373,6 +380,7 @@ def train_loop( on_train = callbacks.on_train on_before_optimize = callbacks.on_before_optimize on_after_optimize = callbacks.on_after_optimize + on_after_epoch = callbacks.on_after_epoch on_eval = callbacks.on_eval on_sample = callbacks.on_sample on_checkpoint = callbacks.on_checkpoint @@ -434,6 +442,8 @@ def train_loop( accelerator.wait_for_everyone() + on_after_epoch(lr_scheduler.get_last_lr()[0]) + if val_dataloader is not None: model.eval() @@ -512,8 +522,7 @@ def train( val_dataloader: Optional[DataLoader], optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, - callbacks_fn: Callable[..., TrainingCallbacks], - prepare_unet: bool = False, + strategy: TrainingStrategy, num_train_epochs: int = 100, sample_frequency: int = 20, checkpoint_frequency: int = 50, @@ -524,12 +533,12 @@ def train( ): prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] - if prepare_unet: + if strategy.prepare_unet: prep.append(unet) prep = accelerator.prepare(*prep) - if prepare_unet: + if strategy.prepare_unet: text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep else: text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep @@ -542,7 +551,7 @@ def train( model.requires_grad_(False) model.eval() - callbacks = callbacks_fn( + callbacks = strategy.callbacks( accelerator=accelerator, unet=unet, text_encoder=text_encoder, diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 93c81cb..bc26ee6 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.tokenizer import MultiCLIPTokenizer from training.util import EMAModel -from training.functional import TrainingCallbacks, save_samples +from training.functional import TrainingStrategy, TrainingCallbacks, save_samples -def dreambooth_strategy( +def dreambooth_strategy_callbacks( accelerator: Accelerator, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, @@ -185,3 +185,9 @@ def dreambooth_strategy( on_checkpoint=on_checkpoint, on_sample=on_sample, ) + + +dreambooth_strategy = TrainingStrategy( + callbacks=dreambooth_strategy_callbacks, + prepare_unet=True +) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 00f3529..597abd0 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -15,10 +15,10 @@ from slugify import slugify from models.clip.tokenizer import MultiCLIPTokenizer from training.util import EMAModel -from training.functional import TrainingCallbacks, save_samples +from training.functional import TrainingStrategy, TrainingCallbacks, save_samples -def textual_inversion_strategy( +def textual_inversion_strategy_callbacks( accelerator: Accelerator, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, @@ -119,17 +119,18 @@ def textual_inversion_strategy( with ema_context(): yield - @torch.no_grad() def on_after_optimize(lr: float): + if use_ema: + ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + + @torch.no_grad() + def on_after_epoch(lr: float): if use_emb_decay: text_encoder.text_model.embeddings.normalize( emb_decay_target, min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) ) - if use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - def on_log(): if use_ema: return {"ema_decay": ema_embeddings.decay} @@ -157,7 +158,13 @@ def textual_inversion_strategy( on_train=on_train, on_eval=on_eval, on_after_optimize=on_after_optimize, + on_after_epoch=on_after_epoch, on_log=on_log, on_checkpoint=on_checkpoint, on_sample=on_sample, ) + + +textual_inversion_strategy = TrainingStrategy( + callbacks=textual_inversion_strategy_callbacks, +) diff --git a/training/util.py b/training/util.py index 557b196..237626f 100644 --- a/training/util.py +++ b/training/util.py @@ -1,18 +1,11 @@ from pathlib import Path import json import copy -from typing import Iterable, Union +from typing import Iterable, Any from contextlib import contextmanager import torch -from transformers import CLIPTextModel -from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler - -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from models.clip.tokenizer import MultiCLIPTokenizer -from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings - def save_args(basepath: Path, args, extra={}): info = {"args": vars(args)} @@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}): class AverageMeter: + avg: Any + def __init__(self, name=None): self.name = name self.reset() -- cgit v1.2.3-70-g09d2