From 5821523a524190490a287c5e2aacb6e72cc3a4cf Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 17 Jan 2023 07:20:45 +0100 Subject: Update --- training/functional.py | 19 ++++++++++++++----- training/strategy/dreambooth.py | 10 ++++++++-- training/strategy/ti.py | 19 +++++++++++++------ training/util.py | 11 +++-------- 4 files changed, 38 insertions(+), 21 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 3d27380..7a3e821 100644 --- a/training/functional.py +++ b/training/functional.py @@ -39,11 +39,18 @@ class TrainingCallbacks(): on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) on_before_optimize: Callable[[int], None] = const() on_after_optimize: Callable[[float], None] = const() + on_after_epoch: Callable[[float], None] = const() on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) on_sample: Callable[[int], None] = const() on_checkpoint: Callable[[int, str], None] = const() +@dataclass +class TrainingStrategy(): + callbacks: Callable[..., TrainingCallbacks] + prepare_unet: bool = False + + def make_grid(images, rows, cols): w, h = images[0].size grid = Image.new('RGB', size=(cols*w, rows*h)) @@ -373,6 +380,7 @@ def train_loop( on_train = callbacks.on_train on_before_optimize = callbacks.on_before_optimize on_after_optimize = callbacks.on_after_optimize + on_after_epoch = callbacks.on_after_epoch on_eval = callbacks.on_eval on_sample = callbacks.on_sample on_checkpoint = callbacks.on_checkpoint @@ -434,6 +442,8 @@ def train_loop( accelerator.wait_for_everyone() + on_after_epoch(lr_scheduler.get_last_lr()[0]) + if val_dataloader is not None: model.eval() @@ -512,8 +522,7 @@ def train( val_dataloader: Optional[DataLoader], optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, - callbacks_fn: Callable[..., TrainingCallbacks], - prepare_unet: bool = False, + strategy: TrainingStrategy, num_train_epochs: int = 100, sample_frequency: int = 20, checkpoint_frequency: int = 50, @@ -524,12 +533,12 @@ def train( ): prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] - if prepare_unet: + if strategy.prepare_unet: prep.append(unet) prep = accelerator.prepare(*prep) - if prepare_unet: + if strategy.prepare_unet: text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep else: text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep @@ -542,7 +551,7 @@ def train( model.requires_grad_(False) model.eval() - callbacks = callbacks_fn( + callbacks = strategy.callbacks( accelerator=accelerator, unet=unet, text_encoder=text_encoder, diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 93c81cb..bc26ee6 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.tokenizer import MultiCLIPTokenizer from training.util import EMAModel -from training.functional import TrainingCallbacks, save_samples +from training.functional import TrainingStrategy, TrainingCallbacks, save_samples -def dreambooth_strategy( +def dreambooth_strategy_callbacks( accelerator: Accelerator, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, @@ -185,3 +185,9 @@ def dreambooth_strategy( on_checkpoint=on_checkpoint, on_sample=on_sample, ) + + +dreambooth_strategy = TrainingStrategy( + callbacks=dreambooth_strategy_callbacks, + prepare_unet=True +) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 00f3529..597abd0 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -15,10 +15,10 @@ from slugify import slugify from models.clip.tokenizer import MultiCLIPTokenizer from training.util import EMAModel -from training.functional import TrainingCallbacks, save_samples +from training.functional import TrainingStrategy, TrainingCallbacks, save_samples -def textual_inversion_strategy( +def textual_inversion_strategy_callbacks( accelerator: Accelerator, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, @@ -119,17 +119,18 @@ def textual_inversion_strategy( with ema_context(): yield - @torch.no_grad() def on_after_optimize(lr: float): + if use_ema: + ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + + @torch.no_grad() + def on_after_epoch(lr: float): if use_emb_decay: text_encoder.text_model.embeddings.normalize( emb_decay_target, min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) ) - if use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - def on_log(): if use_ema: return {"ema_decay": ema_embeddings.decay} @@ -157,7 +158,13 @@ def textual_inversion_strategy( on_train=on_train, on_eval=on_eval, on_after_optimize=on_after_optimize, + on_after_epoch=on_after_epoch, on_log=on_log, on_checkpoint=on_checkpoint, on_sample=on_sample, ) + + +textual_inversion_strategy = TrainingStrategy( + callbacks=textual_inversion_strategy_callbacks, +) diff --git a/training/util.py b/training/util.py index 557b196..237626f 100644 --- a/training/util.py +++ b/training/util.py @@ -1,18 +1,11 @@ from pathlib import Path import json import copy -from typing import Iterable, Union +from typing import Iterable, Any from contextlib import contextmanager import torch -from transformers import CLIPTextModel -from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler - -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from models.clip.tokenizer import MultiCLIPTokenizer -from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings - def save_args(basepath: Path, args, extra={}): info = {"args": vars(args)} @@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}): class AverageMeter: + avg: Any + def __init__(self, name=None): self.name = name self.reset() -- cgit v1.2.3-70-g09d2