| author | Volpeon <git@volpeon.ink> | 2023-01-01 19:19:52 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-01 19:19:52 +0100 |
| commit | adc52fb8821a496bc8d78235bf10466b39df03e0 (patch) | |
| tree | 8a6337a6ac10cbe76c55514ab559c647e69fb1aa /train_ti.py | |
| parent | Fixed accuracy calc, other improvements (diff) | |
| download | textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.gz textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.bz2 textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.zip | |
Updates
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 51 |
1 file changed, 25 insertions(+), 26 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 20a3190..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
| @@ -1,5 +1,4 @@ | |||
| 1 | import argparse | 1 | import argparse |
| 2 | import itertools | ||
| 3 | import math | 2 | import math |
| 4 | import datetime | 3 | import datetime |
| 5 | import logging | 4 | import logging |
| @@ -156,6 +155,12 @@ def parse_args(): | |||
| 156 | help="Tag dropout probability.", | 155 | help="Tag dropout probability.", |
| 157 | ) | 156 | ) |
| 158 | parser.add_argument( | 157 | parser.add_argument( |
| 158 | "--vector_shuffle", | ||
| 159 | type=str, | ||
| 160 | default="auto", | ||
| 161 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | ||
| 162 | ) | ||
| 163 | parser.add_argument( | ||
| 159 | "--dataloader_num_workers", | 164 | "--dataloader_num_workers", |
| 160 | type=int, | 165 | type=int, |
| 161 | default=0, | 166 | default=0, |
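The new `--vector_shuffle` flag replaces the hard-coded `tokenizer.set_use_vector_shuffle(True)` call further down in the diff. The commit does not show how the tokenizer interprets the six modes, but a plausible reading is that each mode decides which positions of a multi-vector token's sub-token ids may be permuted during training. A minimal sketch under that assumption; `shuffle_vectors` is a hypothetical helper, not part of this commit, and the "auto" case collapsing to "between" is a guess:

```python
import random

def shuffle_vectors(ids: list[int], mode: str) -> list[int]:
    # Hypothetical: permute the sub-token ids of one multi-vector embedding
    # according to the mode names listed in the --vector_shuffle help text.
    if mode == "off" or len(ids) <= 1:
        return ids
    if mode == "all":
        return random.sample(ids, len(ids))  # permute every position
    if mode == "trailing":
        return ids[:1] + random.sample(ids[1:], len(ids) - 1)  # first id fixed
    if mode == "leading":
        return random.sample(ids[:-1], len(ids) - 1) + ids[-1:]  # last id fixed
    if mode in ("between", "auto"):
        # keep both ends fixed, shuffle only the interior (guess for "auto")
        if len(ids) <= 2:
            return ids
        return ids[:1] + random.sample(ids[1:-1], len(ids) - 2) + ids[-1:]
    raise ValueError(f"unknown vector_shuffle mode: {mode}")
```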
| @@ -245,7 +250,7 @@ def parse_args(): | |||
| 245 | parser.add_argument( | 250 | parser.add_argument( |
| 246 | "--lr_annealing_exp", | 251 | "--lr_annealing_exp", |
| 247 | type=int, | 252 | type=int, |
| 248 | default=2, | 253 | default=1, |
| 249 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 254 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' |
| 250 | ) | 255 | ) |
| 251 | parser.add_argument( | 256 | parser.add_argument( |
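The default for `--lr_annealing_exp` drops from 2 to 1, i.e. a plain cosine curve. The scheduler code itself is outside this diff, but a minimal sketch of how such an exponent could modify a half-cosine annealing function looks like this:

```python
import math

def half_cos_anneal(progress: float, exp: int = 1) -> float:
    """Map training progress in [0, 1] to an LR multiplier in [1, 0]."""
    # exp=1 (the new default) is the standard half-cosine decay; larger
    # exponents push the multiplier toward zero earlier in training.
    return (0.5 * (1.0 + math.cos(math.pi * progress))) ** exp
```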
| @@ -502,20 +507,14 @@ def main(): | |||
| 502 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) | 507 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
| 503 | basepath.mkdir(parents=True, exist_ok=True) | 508 | basepath.mkdir(parents=True, exist_ok=True) |
| 504 | 509 | ||
| 505 | if args.find_lr: | 510 | accelerator = Accelerator( |
| 506 | accelerator = Accelerator( | 511 | log_with=LoggerType.TENSORBOARD, |
| 507 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 512 | logging_dir=f"{basepath}", |
| 508 | mixed_precision=args.mixed_precision | 513 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 509 | ) | 514 | mixed_precision=args.mixed_precision |
| 510 | else: | 515 | ) |
| 511 | accelerator = Accelerator( | ||
| 512 | log_with=LoggerType.TENSORBOARD, | ||
| 513 | logging_dir=f"{basepath}", | ||
| 514 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 515 | mixed_precision=args.mixed_precision | ||
| 516 | ) | ||
| 517 | 516 | ||
| 518 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 517 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) |
| 519 | 518 | ||
| 520 | args.seed = args.seed or (torch.random.seed() >> 32) | 519 | args.seed = args.seed or (torch.random.seed() >> 32) |
| 521 | set_seed(args.seed) | 520 | set_seed(args.seed) |
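Both code paths now share a single `Accelerator` that always logs to TensorBoard, so `--find_lr` is no longer a special case. The unchanged seeding lines at the end of the hunk also deserve a note: `torch.random.seed()` re-seeds PyTorch's RNG from the OS and returns the new 64-bit seed, and the right-shift by 32 keeps only the high half so the value fits the 32-bit range that NumPy-based seeding accepts. A self-contained sketch, assuming `set_seed` is `accelerate.utils.set_seed` as is typical in diffusers-style scripts:

```python
import torch
from accelerate.utils import set_seed  # assumed import; typical in diffusers-style scripts

def resolve_seed(seed: int | None) -> int:
    # torch.random.seed() re-seeds the RNG from the OS and returns the new
    # 64-bit seed; >> 32 keeps the high half so the result stays inside the
    # 32-bit range that NumPy-based seeding (inside set_seed) accepts.
    seed = seed or (torch.random.seed() >> 32)
    set_seed(seed)
    return seed
```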
| @@ -534,7 +533,7 @@ def main(): | |||
| 534 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( | 533 | checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( |
| 535 | args.pretrained_model_name_or_path, subfolder='scheduler') | 534 | args.pretrained_model_name_or_path, subfolder='scheduler') |
| 536 | 535 | ||
| 537 | tokenizer.set_use_vector_shuffle(True) | 536 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 538 | 537 | ||
| 539 | vae.enable_slicing() | 538 | vae.enable_slicing() |
| 540 | vae.set_use_memory_efficient_attention_xformers(True) | 539 | vae.set_use_memory_efficient_attention_xformers(True) |
| @@ -585,7 +584,7 @@ def main(): | |||
| 585 | ) | 584 | ) |
| 586 | 585 | ||
| 587 | if args.find_lr: | 586 | if args.find_lr: |
| 588 | args.learning_rate = 1e3 | 587 | args.learning_rate = 1e2 |
| 589 | 588 | ||
| 590 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | 589 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs |
| 591 | if args.use_8bit_adam: | 590 | if args.use_8bit_adam: |
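The `--find_lr` sweep now tops out at `1e2` instead of `1e3`, while `lr_finder.run(min_lr=1e-4)` later in the diff sets the floor. The usual LR range test interpolates exponentially between those two bounds; a minimal sketch of that idea, assuming the project's `LRFinder` follows the common Leslie Smith-style test (its actual implementation is outside this diff):

```python
def range_test_lr(step: int, num_steps: int,
                  min_lr: float = 1e-4, max_lr: float = 1e2) -> float:
    # Exponential interpolation from min_lr to max_lr over the test run,
    # so each decade of learning rates gets equal coverage on a log axis.
    return min_lr * (max_lr / min_lr) ** (step / max(num_steps - 1, 1))
```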
| @@ -830,15 +829,6 @@ def main(): | |||
| 830 | 829 | ||
| 831 | return loss, acc, bsz | 830 | return loss, acc, bsz |
| 832 | 831 | ||
| 833 | if args.find_lr: | ||
| 834 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | ||
| 835 | lr_finder.run(min_lr=1e-4) | ||
| 836 | |||
| 837 | plt.savefig(basepath.joinpath("lr.png")) | ||
| 838 | plt.close() | ||
| 839 | |||
| 840 | quit() | ||
| 841 | |||
| 842 | # We need to initialize the trackers we use, and also store our configuration. | 832 | # We need to initialize the trackers we use, and also store our configuration. |
| 843 | # The trackers initializes automatically on the main process. | 833 | # The trackers initializes automatically on the main process. |
| 844 | if accelerator.is_main_process: | 834 | if accelerator.is_main_process: |
| @@ -852,6 +842,15 @@ def main(): | |||
| 852 | config["exclude_collections"] = " ".join(config["exclude_collections"]) | 842 | config["exclude_collections"] = " ".join(config["exclude_collections"]) |
| 853 | accelerator.init_trackers("textual_inversion", config=config) | 843 | accelerator.init_trackers("textual_inversion", config=config) |
| 854 | 844 | ||
| 845 | if args.find_lr: | ||
| 846 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | ||
| 847 | lr_finder.run(min_lr=1e-4) | ||
| 848 | |||
| 849 | plt.savefig(basepath.joinpath("lr.png")) | ||
| 850 | plt.close() | ||
| 851 | |||
| 852 | quit() | ||
| 853 | |||
| 855 | # Train! | 854 | # Train! |
| 856 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | 855 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps |
| 857 | 856 | ||
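For reference, the `total_batch_size` computed here is the effective number of samples contributing to each optimizer update, i.e. the product shown on line 855/856 of the diff:

```python
def effective_batch_size(train_batch_size: int, num_processes: int,
                         grad_accum_steps: int) -> int:
    # e.g. 4 per device * 2 processes * 8 accumulation steps = 64 samples
    # per optimizer step
    return train_batch_size * num_processes * grad_accum_steps
```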
