author     Volpeon <git@volpeon.ink>  2023-01-01 11:36:00 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-01 11:36:00 +0100
commit     b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 (patch)
tree       24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /train_ti.py
parent     Fix (diff)
Fixed accuracy calc, other improvements
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  36
1 file changed, 33 insertions(+), 3 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 19348e5..20a3190 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -225,6 +225,30 @@ def parse_args():
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=2,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
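Note on the new exponent flags: a plausible reading is that the exponent reshapes the cosine warmup curve while keeping its 0-to-1 range. A minimal sketch of that assumed interpretation (the repo's get_one_cycle_schedule is authoritative; this formula is an assumption):

import math

# Assumed interpretation of --lr_warmup_exp: raise the cosine warmup
# curve to an exponent, keeping the 0 -> 1 range but changing its shape.
def cos_warmup(progress: float, exp: int) -> float:
    return (1 - 0.5 * (1 + math.cos(math.pi * progress))) ** exp

for exp in (1, 2):
    print([round(cos_warmup(p / 4, exp), 3) for p in range(5)])
# exp=1: [0.0, 0.146, 0.5, 0.854, 1.0]
# exp=2: [0.0, 0.021, 0.25, 0.729, 1.0]  (stays low longer, then catches up)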
@@ -510,6 +534,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(True)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
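set_use_vector_shuffle is an API of this repo's multi-vector tokenizer; a plausible reading is that it randomly permutes the sub-token ids a multi-vector placeholder expands into, acting as lightweight augmentation. An illustrative sketch only (hypothetical helper, not the repo's implementation):

import random

# Hypothetical illustration of what set_use_vector_shuffle(True) may do:
# a multi-vector placeholder token expands into several token ids, and
# shuffling their order on every encode acts as data augmentation.
def expand_placeholder(sub_token_ids: list[int], use_vector_shuffle: bool) -> list[int]:
    ids = list(sub_token_ids)
    if use_vector_shuffle:
        random.shuffle(ids)  # in-place permutation of the copy
    return ids

print(expand_placeholder([49408, 49409, 49410], use_vector_shuffle=True))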
@@ -559,7 +585,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e3
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
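For context, the --use_8bit_adam branch referenced in this hunk typically swaps in bitsandbytes' AdamW8bit. A sketch of the usual guard (assumes the bitsandbytes package is installed; the fallback choice is illustrative):

import torch

# Sketch of the guard behind --use_8bit_adam. bnb.optim.AdamW8bit keeps
# optimizer state in 8-bit, cutting optimizer memory roughly fourfold.
try:
    import bitsandbytes as bnb
    optimizer_class = bnb.optim.AdamW8bit
except ImportError:
    optimizer_class = torch.optim.AdamW  # fall back to full-precision AdamW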
@@ -706,6 +732,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
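The four new keyword arguments mirror the CLI flags added above. A minimal sketch of a one-cycle schedule with this signature, built on torch.optim.lr_scheduler.LambdaLR; the 10% warmup fraction and the exact curve formulas are assumptions, and the repo's get_one_cycle_schedule may differ:

import math
from torch.optim.lr_scheduler import LambdaLR

# Sketch only: named one_cycle_sketch to avoid claiming it is the
# repo's get_one_cycle_schedule.
def one_cycle_sketch(optimizer, num_training_steps,
                     warmup="cos", annealing="cos",
                     warmup_exp=1, annealing_exp=2):
    warmup_steps = num_training_steps // 10  # assumed warmup fraction

    def lr_lambda(step):
        if step < warmup_steps:
            p = step / max(1, warmup_steps)
            if warmup == "linear":
                return p
            return (1 - 0.5 * (1 + math.cos(math.pi * p))) ** warmup_exp
        p = (step - warmup_steps) / max(1, num_training_steps - warmup_steps)
        if annealing == "linear":
            return 1 - p
        if annealing == "half_cos":
            return 1 - (1 - math.cos(math.pi * p / 2)) ** annealing_exp
        return (0.5 * (1 + math.cos(math.pi * p))) ** annealing_exp

    return LambdaLR(optimizer, lr_lambda)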
@@ -796,13 +826,13 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
     if args.find_lr:
         lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(min_lr=1e-4)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
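The accuracy fix in this hunk compares the prediction against its regression target (the sampled noise, or the v-prediction) rather than the clean latents, which the model never approximates directly. A self-contained illustration of the corrected metric (tensor shapes are illustrative):

import torch
import torch.nn.functional as F

# Corrected metric: compare the model output with its training target,
# not with the clean latents.
model_pred = torch.randn(4, 4, 64, 64)
target = model_pred.clone()  # illustrative: identical tensors

loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
acc = (model_pred == target).float().mean()  # fraction of exactly equal elements
print(loss.item(), acc.item())  # -> 0.0 1.0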