From 5821523a524190490a287c5e2aacb6e72cc3a4cf Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 17 Jan 2023 07:20:45 +0100
Subject: Update

---
 training/strategy/dreambooth.py | 10 ++++++++--
 training/strategy/ti.py         | 19 +++++++++++++------
 2 files changed, 21 insertions(+), 8 deletions(-)

(limited to 'training/strategy')

diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 93c81cb..bc26ee6 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import TrainingCallbacks, save_samples
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-def dreambooth_strategy(
+def dreambooth_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -185,3 +185,9 @@ def dreambooth_strategy(
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
+
+
+dreambooth_strategy = TrainingStrategy(
+    callbacks=dreambooth_strategy_callbacks,
+    prepare_unet=True
+)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 00f3529..597abd0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,10 +15,10 @@ from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import TrainingCallbacks, save_samples
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-def textual_inversion_strategy(
+def textual_inversion_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -119,17 +119,18 @@ def textual_inversion_strategy(
         with ema_context():
             yield
 
-    @torch.no_grad()
     def on_after_optimize(lr: float):
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    @torch.no_grad()
+    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
                 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
             )
 
-        if use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
     def on_log():
         if use_ema:
             return {"ema_decay": ema_embeddings.decay}
@@ -157,7 +158,13 @@ def textual_inversion_strategy(
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
+        on_after_epoch=on_after_epoch,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
+
+
+textual_inversion_strategy = TrainingStrategy(
+    callbacks=textual_inversion_strategy_callbacks,
+)
--
cgit v1.2.3-54-g00ecf
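
Note: the patch imports TrainingStrategy from training.functional but does not show its definition. A minimal sketch of a container consistent with how the patch uses it, assuming it merely pairs the callbacks factory with per-strategy preparation flags (everything beyond the callbacks and prepare_unet fields, including the types and default, is an assumption, not the repo's actual code):

# Hypothetical sketch of the TrainingStrategy container this patch relies on.
# Only the `callbacks` and `prepare_unet` fields are evidenced by the diff;
# the types and the default value are assumptions.
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class TrainingStrategy:
    # Factory producing the TrainingCallbacks for one training run, e.g.
    # dreambooth_strategy_callbacks or textual_inversion_strategy_callbacks.
    callbacks: Callable[..., Any]
    # Whether the shared training loop should run the UNet through
    # accelerator.prepare(): DreamBooth fine-tunes the UNet, so it opts in;
    # textual inversion leaves it at the default.
    prepare_unet: bool = False

Bundling the flag with the callbacks factory lets a single training loop stay strategy-agnostic: it can check strategy.prepare_unet and call strategy.callbacks(...) instead of special-casing DreamBooth versus textual inversion.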