From bed44095ab99440467c2f302899b970c92baebf8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 4 Jan 2023 10:32:58 +0100
Subject: Better eval generator

---
 train_ti.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index 5d6eafc..6f116c3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -598,7 +598,7 @@ def main():
         )
 
     if args.find_lr:
-        args.learning_rate = 1e-6
+        args.learning_rate = 1e-5
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -799,7 +799,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch, eval: bool = False):
+    def loop(step: int, batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
         latents = latents * 0.18215
@@ -808,7 +808,7 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
         timesteps = torch.randint(
             0,
             noise_scheduler.config.num_train_timesteps,
@@ -881,7 +881,7 @@ def main():
             on_train=on_train,
             on_eval=on_eval,
         )
-        lr_finder.run(end_lr=1e2)
+        lr_finder.run(end_lr=1e3)
 
         plt.savefig(basepath.joinpath("lr.png"), dpi=300)
         plt.close()
@@ -954,7 +954,7 @@ def main():
 
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(step, batch)
 
                 accelerator.backward(loss)
 
@@ -998,7 +998,7 @@ def main():
 
         with torch.inference_mode():
            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch, True)
+                loss, acc, bsz = loop(step, batch, True)
 
                 loss = loss.detach_()
                 acc = acc.detach_()
-- 
cgit v1.2.3-54-g00ecf
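
Note (not part of the patch): the functional change is that loop() now receives the batch index and, in eval mode, seeds the timestep generator with args.seed + step rather than a fixed args.seed. With a fixed seed, every validation batch drew the identical timesteps; offsetting by step keeps evaluation deterministic across epochs while still varying the timesteps between batches. The find_lr tweaks (starting LR 1e-5, end_lr=1e3) merely shift the learning-rate sweep upward and are independent of this. Below is a minimal, self-contained sketch of the seeding pattern; names such as seed and num_train_timesteps are illustrative stand-ins, not taken from the repository:

    import torch

    # Illustrative values; the real script reads these from its args/config.
    seed = 42
    num_train_timesteps = 1000
    device = torch.device("cpu")

    def sample_timesteps(step: int, bsz: int, eval: bool = False) -> torch.Tensor:
        # Eval: a fresh generator seeded with seed + step, so the same batch
        # index always draws the same timesteps (eval losses stay comparable
        # across epochs) while different batches still get different timesteps.
        # Train: generator=None leaves sampling fully random.
        gen = torch.Generator(device=device).manual_seed(seed + step) if eval else None
        return torch.randint(0, num_train_timesteps, (bsz,), generator=gen, device=device)

    # Deterministic for a given step:
    assert torch.equal(sample_timesteps(0, 4, eval=True), sample_timesteps(0, 4, eval=True))
    # Different seed per step, so draws differ between batches:
    print(sample_timesteps(0, 4, eval=True))
    print(sample_timesteps(1, 4, eval=True))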