| author | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-01 11:36:00 +0100 |
| commit | b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 | |
| tree | 24fd6d9f3a92ce9f5cccd5cdd914edfee665af71 /train_ti.py | |
| parent | Fix | |
Fixed accuracy calc, other improvements
Diffstat (limited to 'train_ti.py')

| mode | file | lines |
|---|---|---|
| -rw-r--r-- | train_ti.py | 36 |

1 file changed, 33 insertions(+), 3 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index 19348e5..20a3190 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -225,6 +225,30 @@ def parse_args():
         help="Number of restart cycles in the lr scheduler."
     )
     parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=2,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
```
```diff
@@ -510,6 +534,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(True)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
```
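`set_use_vector_shuffle` is now enabled unconditionally for training. The method belongs to this project's tokenizer wrapper, not stock `diffusers`; a minimal sketch of the idea, with hypothetical names, assuming a placeholder token that expands to several sub-token ids:

```python
import random

# Sketch of the vector-shuffle idea (names hypothetical): when a learned
# placeholder expands to multiple sub-tokens, permute them on each encode
# so the multi-vector embedding does not overfit to one fixed ordering.
class ShufflingTokenizerMixin:
    def __init__(self) -> None:
        self._use_vector_shuffle = False

    def set_use_vector_shuffle(self, enable: bool) -> None:
        self._use_vector_shuffle = enable

    def expand_placeholder(self, sub_token_ids: list[int]) -> list[int]:
        if self._use_vector_shuffle:
            return random.sample(sub_token_ids, len(sub_token_ids))
        return sub_token_ids
```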
```diff
@@ -559,7 +585,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e3
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
```
```diff
@@ -706,6 +732,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
```
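For orientation, a minimal sketch of how a one-cycle factory could consume these parameters, assuming the `warmup_curve`/`annealing_curve` helpers sketched after the first hunk and a PyTorch `LambdaLR`; the project's actual `get_one_cycle_schedule` may split the phases differently:

```python
from torch.optim.lr_scheduler import LambdaLR

def get_one_cycle_schedule_sketch(optimizer, num_training_steps,
                                  warmup="cos", annealing="cos",
                                  warmup_exp=1, annealing_exp=2,
                                  warmup_fraction=0.1):
    # warmup_fraction is an assumed knob; the real factory may derive
    # the phase boundary differently.
    split = max(1, int(num_training_steps * warmup_fraction))

    def lr_lambda(step: int) -> float:
        if step < split:
            return warmup_curve(step / split, warmup, warmup_exp)
        progress = (step - split) / max(1, num_training_steps - split)
        return annealing_curve(progress, annealing, annealing_exp)

    # LambdaLR multiplies the optimizer's base lr by lr_lambda(step)
    return LambdaLR(optimizer, lr_lambda)
```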
```diff
@@ -796,13 +826,13 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
     if args.find_lr:
         lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-6, num_train_batches=1)
+        lr_finder.run(min_lr=1e-4)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
```
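Two fixes land in the last hunk: the accuracy metric now compares `model_pred` against the loss target rather than the input latents, making it consistent with the MSE objective, and the LR finder sweep now starts at `1e-4` and, via the `args.learning_rate` bump above, presumably tops out at `1e3`. A minimal sketch of such an exponential sweep, not the project's `LRFinder` class, whose interface is only partially visible here:

```python
# Minimal sketch of an exponential learning-rate sweep in the spirit of
# LRFinder.run(min_lr=1e-4); max_lr stands in for args.learning_rate
# (raised to 1e3 in this commit). Not the actual class.
def lr_sweep(optimizer, batches, loss_fn, min_lr=1e-4, max_lr=1e3, num_steps=100):
    lrs, losses = [], []
    growth = (max_lr / min_lr) ** (1.0 / (num_steps - 1))
    for step, batch in zip(range(num_steps), batches):
        lr = min_lr * growth ** step
        for group in optimizer.param_groups:
            group["lr"] = lr
        loss = loss_fn(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lrs.append(lr)
        losses.append(loss.item())
    # plot losses vs. lrs (log x-axis) and pick the steepest-descent region
    return lrs, losses
```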
