From bed44095ab99440467c2f302899b970c92baebf8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 4 Jan 2023 10:32:58 +0100 Subject: Better eval generator --- train_dreambooth.py | 8 ++++---- train_ti.py | 12 ++++++------ training/lr.py | 6 +++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index 6d9bae8..5e6e35d 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -848,7 +848,7 @@ def main(): def on_eval(): tokenizer.eval() - def loop(batch, eval: bool = False): + def loop(step: int, batch, eval: bool = False): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -857,7 +857,7 @@ def main(): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None + timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, @@ -1008,7 +1008,7 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): - loss, acc, bsz = loop(batch) + loss, acc, bsz = loop(step, batch) accelerator.backward(loss) @@ -1065,7 +1065,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(batch, True) + loss, acc, bsz = loop(step, batch, True) loss = loss.detach_() acc = acc.detach_() diff --git a/train_ti.py b/train_ti.py index 5d6eafc..6f116c3 100644 --- a/train_ti.py +++ b/train_ti.py @@ -598,7 +598,7 @@ def main(): ) if args.find_lr: - args.learning_rate = 1e-6 + args.learning_rate = 1e-5 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -799,7 +799,7 @@ def main(): def on_eval(): tokenizer.eval() - def loop(batch, eval: bool = False): + def loop(step: int, batch, eval: bool = False): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 @@ -808,7 +808,7 @@ def main(): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None + timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, @@ -881,7 +881,7 @@ def main(): on_train=on_train, on_eval=on_eval, ) - lr_finder.run(end_lr=1e2) + lr_finder.run(end_lr=1e3) plt.savefig(basepath.joinpath("lr.png"), dpi=300) plt.close() @@ -954,7 +954,7 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): - loss, acc, bsz = loop(batch) + loss, acc, bsz = loop(step, batch) accelerator.backward(loss) @@ -998,7 +998,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(batch, True) + loss, acc, bsz = loop(step, batch, True) loss = loss.detach_() acc = acc.detach_() diff --git a/training/lr.py b/training/lr.py index a3144ba..c8dc040 100644 --- a/training/lr.py +++ b/training/lr.py @@ -24,7 +24,7 @@ class LRFinder(): optimizer, train_dataloader, val_dataloader, - loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]], + loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], on_train: Callable[[], None] = noop, on_eval: Callable[[], None] = noop ): @@ -89,7 +89,7 @@ class LRFinder(): break with self.accelerator.accumulate(self.model): - loss, acc, bsz = self.loss_fn(batch) + loss, acc, bsz = self.loss_fn(step, batch) self.accelerator.backward(loss) @@ -108,7 +108,7 @@ class LRFinder(): if step >= num_val_batches: break - loss, acc, bsz = self.loss_fn(batch, True) + loss, acc, bsz = self.loss_fn(step, batch, True) avg_loss.update(loss.detach_(), bsz) avg_acc.update(acc.detach_(), bsz) -- cgit v1.2.3-70-g09d2