commit adc52fb8821a496bc8d78235bf10466b39df03e0
Author:    Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
Committer: Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
Tree:      8a6337a6ac10cbe76c55514ab559c647e69fb1aa
Parent:    Fixed accuracy calc, other improvements
Updates
Diffstat (limited to 'train_ti.py'):
 train_ti.py | 51 +-
 1 file changed, 25 insertions(+), 26 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 20a3190..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import datetime
 import logging
@@ -156,6 +155,12 @@ def parse_args(): | |||
156 | help="Tag dropout probability.", | 155 | help="Tag dropout probability.", |
157 | ) | 156 | ) |
158 | parser.add_argument( | 157 | parser.add_argument( |
158 | "--vector_shuffle", | ||
159 | type=str, | ||
160 | default="auto", | ||
161 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | ||
162 | ) | ||
163 | parser.add_argument( | ||
159 | "--dataloader_num_workers", | 164 | "--dataloader_num_workers", |
160 | type=int, | 165 | type=int, |
161 | default=0, | 166 | default=0, |
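Note: the new --vector_shuffle option replaces the hard-coded boolean passed to set_use_vector_shuffle (see the hunk further down). The tokenizer's shuffle implementation is not part of this diff, so the following is only a sketch of how mode strings like these could be interpreted for a token embedded as several vectors; the function name and the exact meaning of each mode (especially "auto") are assumptions.

    import random

    # Hypothetical reading of the --vector_shuffle modes for a token that is
    # embedded as several vectors (represented here by their ids). Not the
    # actual tokenizer code, which this diff does not show.
    def shuffle_vectors(ids, mode="auto"):
        if mode == "off" or len(ids) < 2:
            return list(ids)
        if mode == "all":                 # shuffle every position
            return random.sample(ids, len(ids))
        if mode == "trailing":            # keep the first vector in place
            return ids[:1] + random.sample(ids[1:], len(ids) - 1)
        if mode == "leading":             # keep the last vector in place
            return random.sample(ids[:-1], len(ids) - 1) + ids[-1:]
        if mode in ("between", "auto"):   # keep both ends in place; treating
            # "auto" like "between" is a guess, the diff does not define it
            return ids[:1] + random.sample(ids[1:-1], len(ids) - 2) + ids[-1:]
        raise ValueError(f"unknown vector_shuffle mode: {mode}")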
@@ -245,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
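For context on the default change above: the annealing code is not in this diff, so the snippet below is only a guess at how an exponent typically modifies a half-cosine schedule (the function name is hypothetical). With exp=1, the new default, the factor is a plain half-cosine; larger exponents pull the learning rate down faster.

    import math

    # Assumed shape of a "half_cos" annealing factor with an exponent applied.
    # progress runs from 0.0 (annealing starts) to 1.0 (training ends).
    def half_cos_factor(progress, exp=1):
        return (0.5 * (1.0 + math.cos(math.pi * progress))) ** exp

    # half_cos_factor(0.5, exp=1) == 0.5, half_cos_factor(0.5, exp=2) == 0.25,
    # i.e. a higher exponent decays the learning rate more aggressively.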
@@ -502,20 +507,14 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)

-    if args.find_lr:
-        accelerator = Accelerator(
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
-    else:
-        accelerator = Accelerator(
-            log_with=LoggerType.TENSORBOARD,
-            logging_dir=f"{basepath}",
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )

     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

     args.seed = args.seed or (torch.random.seed() >> 32)
     set_seed(args.seed)
@@ -534,7 +533,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')

-    tokenizer.set_use_vector_shuffle(True)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -585,7 +584,7 @@ def main():
     )

     if args.find_lr:
-        args.learning_rate = 1e3
+        args.learning_rate = 1e2

     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -830,15 +829,6 @@ def main(): | |||
830 | 829 | ||
831 | return loss, acc, bsz | 830 | return loss, acc, bsz |
832 | 831 | ||
833 | if args.find_lr: | ||
834 | lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) | ||
835 | lr_finder.run(min_lr=1e-4) | ||
836 | |||
837 | plt.savefig(basepath.joinpath("lr.png")) | ||
838 | plt.close() | ||
839 | |||
840 | quit() | ||
841 | |||
842 | # We need to initialize the trackers we use, and also store our configuration. | 832 | # We need to initialize the trackers we use, and also store our configuration. |
843 | # The trackers initializes automatically on the main process. | 833 | # The trackers initializes automatically on the main process. |
844 | if accelerator.is_main_process: | 834 | if accelerator.is_main_process: |
@@ -852,6 +842,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("textual_inversion", config=config)

+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

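Taken together, the find_lr changes in this commit lower the sweep's maximum rate from 1e3 to 1e2, drop the separate logging-free Accelerator, and move the finder to after accelerator.init_trackers(), so it runs with the same TensorBoard-enabled setup as a normal training run. The LRFinder class itself is not in this diff; the sketch below is a generic learning-rate range test of the kind such a class usually implements (all names and the exponential sweep are assumptions): the rate grows from min_lr toward the configured maximum while the loss is recorded, and the resulting curve is plotted for manual inspection.

    import matplotlib.pyplot as plt

    # Generic LR range test; NOT the repository's LRFinder. `loop` is assumed
    # to behave like the diff's loop(): take a batch, return (loss, acc, bsz).
    def lr_range_test(optimizer, loop, batches, min_lr=1e-4, max_lr=1e2, steps=100):
        gamma = (max_lr / min_lr) ** (1.0 / steps)  # exponential multiplier per step
        lr, lrs, losses = min_lr, [], []
        for _, batch in zip(range(steps), batches):
            for group in optimizer.param_groups:    # apply the current rate
                group["lr"] = lr
            loss, _acc, _bsz = loop(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            lrs.append(lr)
            losses.append(loss.item())
            lr *= gamma
        plt.plot(lrs, losses)                       # loss vs. learning rate
        plt.xscale("log")
        plt.xlabel("learning rate")
        plt.ylabel("loss")
        # the caller then saves and closes the figure, as train_ti.py does with
        # plt.savefig(basepath.joinpath("lr.png")) and plt.close()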