path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-01 19:19:52 +0100
commit     adc52fb8821a496bc8d78235bf10466b39df03e0 (patch)
tree       8a6337a6ac10cbe76c55514ab559c647e69fb1aa /train_ti.py
parent     Fixed accuracy calc, other improvements (diff)
download   textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.gz
           textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.tar.bz2
           textual-inversion-diff-adc52fb8821a496bc8d78235bf10466b39df03e0.zip
Updates
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  51
1 file changed, 25 insertions, 26 deletions
diff --git a/train_ti.py b/train_ti.py
index 20a3190..775b918 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,5 +1,4 @@
 import argparse
-import itertools
 import math
 import datetime
 import logging
@@ -156,6 +155,12 @@ def parse_args():
         help="Tag dropout probability.",
     )
     parser.add_argument(
+        "--vector_shuffle",
+        type=str,
+        default="auto",
+        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
+    )
+    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
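
The new --vector_shuffle option is declared as a plain string; only the help text lists the accepted values. As a small illustrative sketch (not code from this repository), the same list could be enforced at parse time with argparse's choices parameter:

import argparse

# Illustrative only: reject unknown shuffle modes at parse time instead of
# relying on the help text alone.
VECTOR_SHUFFLE_MODES = ["all", "trailing", "leading", "between", "auto", "off"]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--vector_shuffle",
    type=str,
    default="auto",
    choices=VECTOR_SHUFFLE_MODES,
    help="Vector shuffling algorithm.",
)

args = parser.parse_args(["--vector_shuffle", "trailing"])
print(args.vector_shuffle)  # trailing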
@@ -245,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=2,
+        default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
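
The default of --lr_annealing_exp changes from 2 to 1. The help text says the exponent modifies the "half_cos"/"cos" annealing function; the scheduler itself is not part of this diff, so the following is only an assumed sketch of how such an exponent is typically applied:

import math

def half_cos_anneal(step, total_steps, lr_max, exp=1):
    # Assumed form: half-cosine decay from lr_max to 0, raised to `exp`.
    # exp=1 gives the plain half-cosine; larger exponents decay faster.
    progress = step / max(total_steps, 1)
    return lr_max * (0.5 * (1.0 + math.cos(math.pi * progress))) ** exp

print(half_cos_anneal(500, 1000, 1e-4, exp=1))  # 5.0e-05 at the halfway point
print(half_cos_anneal(500, 1000, 1e-4, exp=2))  # 2.5e-05 at the halfway point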
@@ -502,20 +507,14 @@ def main():
     basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
-    if args.find_lr:
-        accelerator = Accelerator(
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
-    else:
-        accelerator = Accelerator(
-            log_with=LoggerType.TENSORBOARD,
-            logging_dir=f"{basepath}",
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            mixed_precision=args.mixed_precision
-        )
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{basepath}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
     args.seed = args.seed or (torch.random.seed() >> 32)
     set_seed(args.seed)
@@ -534,7 +533,7 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    tokenizer.set_use_vector_shuffle(True)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
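
set_use_vector_shuffle previously received a hardcoded True and now receives the --vector_shuffle string. The tokenizer implementation is not shown in this diff, so the following is only a hypothetical reading of what the mode names could mean for the sub-token IDs of a multi-vector embedding:

import random

def shuffle_ids(ids, mode="auto"):
    # Hypothetical reading of the mode names; not taken from the repository.
    # ids: the sub-token IDs that make up one multi-vector embedding.
    ids = list(ids)
    if mode == "off" or mode is False:
        return ids
    if mode == "all" or mode is True:
        random.shuffle(ids)
        return ids
    if mode == "trailing":          # keep the first ID fixed
        rest = ids[1:]
        random.shuffle(rest)
        return ids[:1] + rest
    if mode == "leading":           # keep the last ID fixed
        rest = ids[:-1]
        random.shuffle(rest)
        return rest + ids[-1:]
    # "between" and, by assumption, "auto": keep first and last IDs fixed
    if len(ids) < 3:
        return ids
    mid = ids[1:-1]
    random.shuffle(mid)
    return ids[:1] + mid + ids[-1:]

print(shuffle_ids([1, 2, 3, 4, 5], mode="between"))  # e.g. [1, 3, 4, 2, 5]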
@@ -585,7 +584,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e3
+        args.learning_rate = 1e2
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -830,15 +829,6 @@ def main():
 
         return loss, acc, bsz
 
-    if args.find_lr:
-        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-4)
-
-        plt.savefig(basepath.joinpath("lr.png"))
-        plt.close()
-
-        quit()
-
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
@@ -852,6 +842,15 @@ def main():
         config["exclude_collections"] = " ".join(config["exclude_collections"])
         accelerator.init_trackers("textual_inversion", config=config)
 
+    if args.find_lr:
+        lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
+        lr_finder.run(min_lr=1e-4)
+
+        plt.savefig(basepath.joinpath("lr.png"))
+        plt.close()
+
+        quit()
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
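
The --find_lr block itself is unchanged; it only moves below accelerator.init_trackers(...), so the sweep now runs with the same tracker-enabled Accelerator as a regular training run (that reading of the move is an inference, not stated in the commit message). LRFinder is a class from this repository whose internals are not shown here; as a generic stand-in, a learning-rate range test usually sweeps the rate exponentially between a floor such as min_lr=1e-4 and a ceiling (plausibly the args.learning_rate = 1e2 set above) while recording the loss:

import math

def lr_range_test(min_lr, max_lr, num_steps, eval_loss):
    # Generic LR range test sketch, not the repository's LRFinder:
    # sweep the learning rate exponentially from min_lr to max_lr and
    # record the loss so a usable training rate can be read off the curve.
    curve = []
    for step in range(num_steps):
        t = step / max(num_steps - 1, 1)
        lr = min_lr * (max_lr / min_lr) ** t   # exponential interpolation
        curve.append((lr, eval_loss(lr)))
    return curve

# Toy usage with a made-up loss curve whose minimum sits near lr = 1e-1:
curve = lr_range_test(1e-4, 1e2, num_steps=20,
                      eval_loss=lambda lr: (math.log10(lr) + 1.0) ** 2)
best_lr, _ = min(curve, key=lambda p: p[1])
print(best_lr)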