From 5b9a3de142e7a645573b4f4a8c1ce9c59746ab08 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sun, 15 Jan 2023 09:25:30 +0100
Subject: Added functional trainer

---
 train_ti.py            | 49 ++++++++++++++++-----------------
 trainer_old/base.py    | 14 +++-------
 training/functional.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 101 insertions(+), 37 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 78c1b5c..97e4e72 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -703,17 +703,27 @@ def main():
             warmup_epochs=args.lr_warmup_epochs,
         )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
-    if args.gradient_checkpointing:
-        unet.train()
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        text_encoder=text_encoder,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+    )
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        if args.gradient_checkpointing:
+            unet.train()
 
     @contextmanager
     def on_train(epoch: int):
@@ -752,16 +762,6 @@ def main():
             return {"ema_decay": ema_embeddings.decay}
         return {}
 
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
     checkpointer = TextualInversionCheckpointer(
         dtype=weight_dtype,
         train_dataloader=train_dataloader,
@@ -803,18 +803,15 @@ def main():
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        train_loop(
-            accelerator=accelerator,
+        trainer(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            model=text_encoder,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
+            num_train_epochs=args.num_train_epochs,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            num_epochs=args.num_train_epochs,
+            prior_loss_weight=args.prior_loss_weight,
+            on_prepare=on_prepare,
             on_log=on_log,
             on_train=on_train,
             on_after_optimize=on_after_optimize,
diff --git a/trainer_old/base.py b/trainer_old/base.py
index 1f85e71..5903d96 100644
--- a/trainer_old/base.py
+++ b/trainer_old/base.py
@@ -174,19 +174,13 @@ class TrainingStrategy():
 
     @contextmanager
     def on_train(self, epoch: int):
-        try:
-            self.tokenizer.train()
-            yield
-        finally:
-            pass
+        self.tokenizer.train()
+        yield
 
     @contextmanager
     def on_eval(self):
-        try:
-            self.tokenizer.eval()
-            yield
-        finally:
-            pass
+        self.tokenizer.eval()
+        yield
 
     def on_before_optimize(self, epoch: int):
         ...
diff --git a/training/functional.py b/training/functional.py
index c5b514a..1f2ca6d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,6 +1,7 @@
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Optional
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -376,3 +377,75 @@ def train_loop(
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
             accelerator.end_training()
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    prior_loss_weight: float = 0,
+    on_prepare: Callable[[], dict[str, Any]] = const({}),
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        prior_loss_weight,
+        seed,
+    )
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval,
+        on_sample=on_sample,
+        on_checkpoint=on_checkpoint,
+    )
+
+    accelerator.free_memory()
-- 
cgit v1.2.3-70-g09d2