From b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 1 Jan 2023 11:36:00 +0100 Subject: Fixed accuracy calc, other improvements --- train_ti.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index 19348e5..20a3190 100644 --- a/train_ti.py +++ b/train_ti.py @@ -224,6 +224,30 @@ def parse_args(): default=None, help="Number of restart cycles in the lr scheduler." ) + parser.add_argument( + "--lr_warmup_func", + type=str, + default="cos", + help='Choose between ["linear", "cos"]' + ) + parser.add_argument( + "--lr_warmup_exp", + type=int, + default=1, + help='If lr_warmup_func is "cos", exponent to modify the function' + ) + parser.add_argument( + "--lr_annealing_func", + type=str, + default="cos", + help='Choose between ["linear", "half_cos", "cos"]' + ) + parser.add_argument( + "--lr_annealing_exp", + type=int, + default=2, + help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' + ) parser.add_argument( "--use_8bit_adam", action="store_true", @@ -510,6 +534,8 @@ def main(): checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder='scheduler') + tokenizer.set_use_vector_shuffle(True) + vae.enable_slicing() vae.set_use_memory_efficient_attention_xformers(True) unet.set_use_memory_efficient_attention_xformers(True) @@ -559,7 +585,7 @@ def main(): ) if args.find_lr: - args.learning_rate = 1e2 + args.learning_rate = 1e3 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -706,6 +732,10 @@ def main(): lr_scheduler = get_one_cycle_schedule( optimizer=optimizer, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + warmup=args.lr_warmup_func, + annealing=args.lr_annealing_func, + warmup_exp=args.lr_warmup_exp, + annealing_exp=args.lr_annealing_exp, ) elif args.lr_scheduler == "cosine_with_restarts": lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( @@ -796,13 +826,13 @@ def main(): else: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - acc = (model_pred == latents).float().mean() + acc = (model_pred == target).float().mean() return loss, acc, bsz if args.find_lr: lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) - lr_finder.run(min_lr=1e-6, num_train_batches=1) + lr_finder.run(min_lr=1e-4) plt.savefig(basepath.joinpath("lr.png")) plt.close() -- cgit v1.2.3-54-g00ecf