From 5821523a524190490a287c5e2aacb6e72cc3a4cf Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Tue, 17 Jan 2023 07:20:45 +0100
Subject: Introduce TrainingStrategy; add an on_after_epoch callback

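Bundle each training method's callback factory and its prepare_unet
flag into a TrainingStrategy dataclass, and make train() take a single
strategy argument in place of callbacks_fn + prepare_unet. The
dreambooth and textual inversion modules now export ready-made
instances (dreambooth_strategy, textual_inversion_strategy) built from
their renamed *_strategy_callbacks factories.

Also add an on_after_epoch callback that train_loop() invokes once per
epoch, after the optimizer steps and before validation, passing the
last learning rate. The textual inversion strategy uses it to apply
embedding decay once per epoch instead of after every optimizer step,
while the EMA update stays in on_after_optimize. Finally, prune unused
imports from training/util.py (importing Any instead of Union) and
declare an explicit avg: Any attribute on AverageMeter.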
---
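Note: call sites are expected to change roughly as follows. This is a
sketch only; the real call sites live outside this diff, and every
argument except strategy= is illustrative.

    from training.functional import train
    from training.strategy.dreambooth import dreambooth_strategy

    # Before: train(..., callbacks_fn=dreambooth_strategy,
    #                    prepare_unet=True)
    # After: the strategy object carries both pieces.
    train(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        strategy=dreambooth_strategy,
        num_train_epochs=100,
    )

Defining a new strategy then reduces to providing a callbacks factory
(again a sketch; a real factory receives accelerator, unet,
text_encoder, etc. as keyword arguments, as train() passes them):

    from training.functional import TrainingStrategy, TrainingCallbacks

    def my_strategy_callbacks(**kwargs) -> TrainingCallbacks:
        # Hypothetical factory: the TrainingCallbacks fields shown in
        # this patch all default to no-ops, so an empty instance is a
        # valid (do-nothing) strategy.
        return TrainingCallbacks()

    my_strategy = TrainingStrategy(callbacks=my_strategy_callbacks)
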
 training/functional.py          | 19 ++++++++++++++-----
 training/strategy/dreambooth.py | 10 ++++++++--
 training/strategy/ti.py         | 15 +++++++++++----
 training/util.py                | 11 +++--------
 4 files changed, 36 insertions(+), 19 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 3d27380..7a3e821 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -39,11 +39,18 @@ class TrainingCallbacks():
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], None] = const()
     on_after_optimize: Callable[[float], None] = const()
+    on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
     on_checkpoint: Callable[[int, str], None] = const()
 
 
+@dataclass
+class TrainingStrategy():
+    callbacks: Callable[..., TrainingCallbacks]
+    prepare_unet: bool = False
+
+
 def make_grid(images, rows, cols):
     w, h = images[0].size
     grid = Image.new('RGB', size=(cols*w, rows*h))
@@ -373,6 +380,7 @@ def train_loop(
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
     on_after_optimize = callbacks.on_after_optimize
+    on_after_epoch = callbacks.on_after_epoch
     on_eval = callbacks.on_eval
     on_sample = callbacks.on_sample
     on_checkpoint = callbacks.on_checkpoint
@@ -434,6 +442,8 @@ def train_loop(
 
             accelerator.wait_for_everyone()
 
+            on_after_epoch(lr_scheduler.get_last_lr()[0])
+
             if val_dataloader is not None:
                 model.eval()
 
@@ -512,8 +522,7 @@ def train(
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    callbacks_fn: Callable[..., TrainingCallbacks],
-    prepare_unet: bool = False,
+    strategy: TrainingStrategy,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -524,12 +533,12 @@ def train(
 ):
     prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
 
-    if prepare_unet:
+    if strategy.prepare_unet:
         prep.append(unet)
 
     prep = accelerator.prepare(*prep)
 
-    if prepare_unet:
+    if strategy.prepare_unet:
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
     else:
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
@@ -542,7 +551,7 @@ def train(
         model.requires_grad_(False)
         model.eval()
 
-    callbacks = callbacks_fn(
+    callbacks = strategy.callbacks(
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 93c81cb..bc26ee6 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import TrainingCallbacks, save_samples
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-def dreambooth_strategy(
+def dreambooth_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -185,3 +185,9 @@ def dreambooth_strategy(
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
+
+
+dreambooth_strategy = TrainingStrategy(
+    callbacks=dreambooth_strategy_callbacks,
+    prepare_unet=True
+)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 00f3529..597abd0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,10 +15,10 @@ from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import TrainingCallbacks, save_samples
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-def textual_inversion_strategy(
+def textual_inversion_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -119,17 +119,18 @@ def textual_inversion_strategy(
         with ema_context():
             yield
 
-    @torch.no_grad()
     def on_after_optimize(lr: float):
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    @torch.no_grad()
+    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
                 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
             )
 
-        if use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
     def on_log():
         if use_ema:
             return {"ema_decay": ema_embeddings.decay}
@@ -157,7 +158,13 @@ def textual_inversion_strategy(
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
+        on_after_epoch=on_after_epoch,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
+
+
+textual_inversion_strategy = TrainingStrategy(
+    callbacks=textual_inversion_strategy_callbacks,
+)
diff --git a/training/util.py b/training/util.py
index 557b196..237626f 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,18 +1,11 @@
 from pathlib import Path
 import json
 import copy
-from typing import Iterable, Union
+from typing import Iterable, Any
 from contextlib import contextmanager
 
 import torch
 
-from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
-
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.tokenizer import MultiCLIPTokenizer
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
-
 
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
@@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
+    avg: Any
+
     def __init__(self, name=None):
         self.name = name
         self.reset()