From bed44095ab99440467c2f302899b970c92baebf8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 4 Jan 2023 10:32:58 +0100 Subject: Better eval generator --- train_dreambooth.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 6d9bae8..5e6e35d 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -848,7 +848,7 @@ def main(): def on_eval(): tokenizer.eval() - def loop(batch, eval: bool = False): + def loop(step: int, batch, eval: bool = False): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * 0.18215 @@ -857,7 +857,7 @@ def main(): noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None + timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, @@ -1008,7 +1008,7 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): - loss, acc, bsz = loop(batch) + loss, acc, bsz = loop(step, batch) accelerator.backward(loss) @@ -1065,7 +1065,7 @@ def main(): with torch.inference_mode(): for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loop(batch, True) + loss, acc, bsz = loop(step, batch, True) loss = loss.detach_() acc = acc.detach_() -- cgit v1.2.3-54-g00ecf