From 403f525d0c6900cc6844c0d2f4ecb385fc131969 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 4 Jan 2023 09:40:24 +0100
Subject: Fixed reproducibility, more consistent validation

---
 train_dreambooth.py | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

(limited to 'train_dreambooth.py')

diff --git a/train_dreambooth.py b/train_dreambooth.py
index df8b54c..6d9bae8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -319,6 +319,12 @@ def parse_args():
         default=1e-08,
         help="Epsilon value for the Adam optimizer"
     )
+    parser.add_argument(
+        "--adam_amsgrad",
+        type=bool,
+        default=False,
+        help="Amsgrad value for the Adam optimizer"
+    )
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -642,7 +648,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e-4
+        args.learning_rate = 1e-6
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -674,6 +680,7 @@ def main():
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
     )
 
     weight_dtype = torch.float32
@@ -730,6 +737,7 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
+        seed=args.seed,
         filter=keyword_filter,
         collate_fn=collate_fn
     )
@@ -840,7 +848,7 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(batch):
+    def loop(batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
         latents = latents * 0.18215
@@ -849,8 +857,14 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps = torch.randint(
+            0,
+            noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            generator=timesteps_gen,
+            device=latents.device,
+        )
         timesteps = timesteps.long()
 
         # Add noise to the latents according to the noise magnitude at each timestep
@@ -1051,7 +1065,7 @@ def main():
 
         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(batch, True)
 
                 loss = loss.detach_()
                 acc = acc.detach_()
-- 
cgit v1.2.3-54-g00ecf
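
Editor's note: the last two hunks make validation reproducible by seeding a torch.Generator only during evaluation, so every validation pass samples the same diffusion timesteps while training remains fully random. Below is a minimal standalone sketch of that idea; the function name sample_timesteps, its parameters, and the example values are illustrative assumptions, not code from the repository.

import torch

def sample_timesteps(bsz, num_train_timesteps, device, eval=False, seed=0):
    # Hypothetical helper mirroring the patch's approach: during evaluation,
    # draw timesteps from a freshly seeded generator so every validation pass
    # uses the same timesteps; during training, keep sampling fully random
    # (generator=None).
    generator = torch.Generator(device=device).manual_seed(seed) if eval else None
    timesteps = torch.randint(
        0,
        num_train_timesteps,
        (bsz,),
        generator=generator,
        device=device,
    )
    return timesteps.long()

# Example: with eval=True and a fixed seed, repeated calls return the same
# timesteps, which keeps the validation loss comparable across epochs.
t1 = sample_timesteps(4, 1000, torch.device("cpu"), eval=True, seed=42)
t2 = sample_timesteps(4, 1000, torch.device("cpu"), eval=True, seed=42)
assert torch.equal(t1, t2)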