| author | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
| commit | b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 (patch) | |
| tree | 24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /train_ti.py | |
| parent | Fix (diff) | |
| download | textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.gz, textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.tar.bz2, textual-inversion-diff-b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5.zip | |
Fixed accuracy calc, other improvements
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 36
1 file changed, 33 insertions(+), 3 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 19348e5..20a3190 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -225,6 +225,30 @@ def parse_args():
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=2,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
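The four new flags pick the shape of the warmup and annealing phases of the one-cycle learning-rate schedule and let an exponent sharpen or flatten the curve; they are consumed by `get_one_cycle_schedule` in a later hunk of this diff (see the sketch there). The actual curve definitions live in the repository's scheduler code, which this diff does not show, so the following is only a rough sketch of what "linear", "cos", and "half_cos" curves with an exponent might look like.

```python
import math

# Rough sketch of the curve shapes the new flags are assumed to select.
# These formulas are illustrative guesses, not the repository's definitions.

def warmup(progress: float, func: str = "cos", exp: int = 1) -> float:
    """Map progress in [0, 1] to an LR multiplier ramping from 0 to 1."""
    if func == "linear":
        return progress
    # "cos": smooth cosine ramp, optionally sharpened by an exponent
    return (0.5 * (1.0 - math.cos(math.pi * progress))) ** exp

def annealing(progress: float, func: str = "cos", exp: int = 2) -> float:
    """Map progress in [0, 1] to an LR multiplier decaying from 1 to 0."""
    if func == "linear":
        return 1.0 - progress
    if func == "half_cos":
        # first quarter wave only: stays high longer before dropping
        return math.cos(0.5 * math.pi * progress) ** exp
    # "cos": full half-period cosine decay
    return (0.5 * (1.0 + math.cos(math.pi * progress))) ** exp
```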
@@ -510,6 +534,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(True)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
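`set_use_vector_shuffle` is a method of this repository's custom multi-vector tokenizer and its implementation is not shown here. As a rough sketch of the underlying idea (a hypothetical helper, not the repository's code): a placeholder token backed by several embedding vectors can have those vectors presented in a random order each step, so the learned concept does not come to depend on a fixed vector order.

```python
import torch

def shuffle_token_vectors(vectors: torch.Tensor) -> torch.Tensor:
    # vectors: (num_vectors, embedding_dim) block backing one multi-vector
    # placeholder token. Returning the rows in a random order is one simple
    # way to regularize multi-vector textual inversion embeddings.
    perm = torch.randperm(vectors.shape[0], device=vectors.device)
    return vectors[perm]
```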
@@ -559,7 +585,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e3
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
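For context, the branch below this hunk swaps the optimizer for bitsandbytes' 8-bit Adam when `--use_8bit_adam` is set. The usual import-with-fallback pattern from similar diffusers training scripts looks roughly like this (a sketch, not copied from this repository):

```python
import torch

try:
    import bitsandbytes as bnb
    optimizer_class = bnb.optim.AdamW8bit  # 8-bit optimizer states, lower VRAM
except ImportError:
    optimizer_class = torch.optim.AdamW    # standard full-precision fallback
```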
@@ -706,6 +732,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
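The new keyword arguments are forwarded to `get_one_cycle_schedule`, whose implementation is not part of this diff. One plausible reading, sketched below under the assumption of a 30% warmup fraction, is a `LambdaLR` whose multiplier follows the warmup curve for the first part of training and the annealing curve afterwards (the curve shapes themselves are sketched after the `parse_args` hunk above).

```python
from torch.optim.lr_scheduler import LambdaLR

def one_cycle_sketch(optimizer, num_training_steps, warmup_fn, annealing_fn, warmup_frac=0.3):
    # Illustrative stand-in for get_one_cycle_schedule (hypothetical, not the
    # repository's code): ramp the LR multiplier up with warmup_fn, then decay
    # it with annealing_fn. Both map progress in [0, 1] to a multiplier.
    rise = max(1, int(num_training_steps * warmup_frac))

    def lr_lambda(step: int) -> float:
        if step < rise:
            return warmup_fn(step / rise)
        return annealing_fn((step - rise) / max(1, num_training_steps - rise))

    return LambdaLR(optimizer, lr_lambda)
```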
@@ -796,13 +826,13 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
     if args.find_lr:
         lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-4)
+        lr_finder.run(min_lr=1e-4)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
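This last hunk is the fix named in the commit message: the element-wise "accuracy" was previously computed against `latents`, a tensor the UNet never predicts, so the comparison was between unrelated quantities; it now uses the same `target` tensor as the MSE loss (the noise, or the velocity for v-prediction). The LR finder call also drops the single-batch restriction and raises the minimum learning rate to 1e-4. Pulled out of the training loop for illustration:

```python
import torch
import torch.nn.functional as F

def loss_and_accuracy(model_pred: torch.Tensor, target: torch.Tensor):
    # MSE loss against the diffusion target, as in the training loop above.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    # Element-wise exact-match "accuracy" against the same target; before this
    # commit the comparison was mistakenly made against the latents.
    acc = (model_pred == target).float().mean()
    return loss, acc
```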