From bed44095ab99440467c2f302899b970c92baebf8 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Wed, 4 Jan 2023 10:32:58 +0100
Subject: Better eval generator

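Seed the eval-time timestep generator with args.seed + step instead of a
fixed args.seed, so each validation batch draws a different but still
reproducible set of timesteps. The loop()/loss_fn callables now take the
step index as their first argument, and the LRFinder type hints are
updated to match. The find_lr range is also widened: the starting
learning rate moves from 1e-6 to 1e-5 and end_lr from 1e2 to 1e3.

A minimal sketch of the new eval sampling, with identifiers borrowed from
the patched loop(); the helper name and argument list are illustrative
only, not code from this repository:

    import torch

    def sample_eval_timesteps(seed, step, bsz, num_train_timesteps, device):
        # seed + step: reproducible across runs, but varied across eval batches
        gen = torch.Generator(device=device).manual_seed(seed + step)
        return torch.randint(
            0, num_train_timesteps, (bsz,), generator=gen, device=device
        )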
---
 train_dreambooth.py |  8 ++++----
 train_ti.py         | 12 ++++++------
 training/lr.py      |  6 +++---
 3 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6d9bae8..5e6e35d 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -848,7 +848,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch, eval: bool = False):
+    def loop(step: int, batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
         latents = latents * 0.18215
@@ -857,7 +857,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
         timesteps = torch.randint(
             0,
             noise_scheduler.config.num_train_timesteps,
@@ -1008,7 +1008,7 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
-                    loss, acc, bsz = loop(batch)
+                    loss, acc, bsz = loop(step, batch)
 
                     accelerator.backward(loss)
 
@@ -1065,7 +1065,7 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    loss, acc, bsz = loop(batch, True)
+                    loss, acc, bsz = loop(step, batch, True)
 
                     loss = loss.detach_()
                     acc = acc.detach_()
diff --git a/train_ti.py b/train_ti.py
index 5d6eafc..6f116c3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -598,7 +598,7 @@ def main():
         )
 
     if args.find_lr:
-        args.learning_rate = 1e-6
+        args.learning_rate = 1e-5
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -799,7 +799,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch, eval: bool = False):
+    def loop(step: int, batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
         latents = latents * 0.18215
@@ -808,7 +808,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
         timesteps = torch.randint(
             0,
             noise_scheduler.config.num_train_timesteps,
@@ -881,7 +881,7 @@ def main():
             on_train=on_train,
             on_eval=on_eval,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(end_lr=1e3)
 
         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
@@ -954,7 +954,7 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
-                    loss, acc, bsz = loop(batch)
+                    loss, acc, bsz = loop(step, batch)
 
                     accelerator.backward(loss)
 
@@ -998,7 +998,7 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    loss, acc, bsz = loop(batch, True)
+                    loss, acc, bsz = loop(step, batch, True)
 
                     loss = loss.detach_()
                     acc = acc.detach_()
diff --git a/training/lr.py b/training/lr.py
index a3144ba..c8dc040 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -24,7 +24,7 @@ class LRFinder():
         optimizer,
         train_dataloader,
         val_dataloader,
-        loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]],
+        loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
@@ -89,7 +89,7 @@ class LRFinder():
                     break
 
                 with self.accelerator.accumulate(self.model):
-                    loss, acc, bsz = self.loss_fn(batch)
+                    loss, acc, bsz = self.loss_fn(step, batch)
 
                     self.accelerator.backward(loss)
 
@@ -108,7 +108,7 @@ class LRFinder():
                     if step >= num_val_batches:
                         break
 
-                    loss, acc, bsz = self.loss_fn(batch, True)
+                    loss, acc, bsz = self.loss_fn(step, batch, True)
                     avg_loss.update(loss.detach_(), bsz)
                     avg_acc.update(acc.detach_(), bsz)
 
-- 
cgit v1.2.3-70-g09d2