path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2023-02-13 17:19:18 +0100
committer  Volpeon <git@volpeon.ink>  2023-02-13 17:19:18 +0100
commit     94b676d91382267e7429bd68362019868affd9d1 (patch)
tree       513697739ab25217cbfcff630299d02b1f6e98c8 /train_ti.py
parent     Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline (diff)
download   textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.gz (.tar.bz2, .zip)
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  19
1 file changed, 11 insertions, 8 deletions
diff --git a/train_ti.py b/train_ti.py
index c79dfa2..171d085 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -143,7 +143,7 @@ def parse_args():
     parser.add_argument(
         "--num_buckets",
         type=int,
-        default=4,
+        default=0,
         help="Number of aspect ratio buckets in either direction.",
     )
     parser.add_argument(
@@ -485,6 +485,9 @@ def parse_args():
 
         if len(args.placeholder_tokens) != len(args.train_data_template):
             raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+    else:
+        if isinstance(args.train_data_template, list):
+            raise ValueError("--train_data_template can't be a list in simultaneous mode")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
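The added else branch pairs with the length check above: when every placeholder token gets its own --train_data_template entry the counts must match, while a list of templates is rejected when all tokens train at once. A minimal standalone sketch of that validation, assuming the branch hangs off a sequential/simultaneous switch that lies outside this hunk's context (validate_templates and the sequential flag are illustrative names, not the script's API):

def validate_templates(placeholder_tokens, train_data_template, sequential):
    # Sequential training: one template per placeholder token.
    if sequential:
        if len(placeholder_tokens) != len(train_data_template):
            raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
    # Simultaneous training: only a single shared template makes sense.
    elif isinstance(train_data_template, list):
        raise ValueError("--train_data_template can't be a list in simultaneous mode")

# One template per token passes in sequential mode ...
validate_templates(["<tokA>", "<tokB>"], ["style", "subject"], sequential=True)
# ... but the same list is rejected in simultaneous mode (raises ValueError):
# validate_templates(["<tokA>", "<tokB>"], ["style", "subject"], sequential=False)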
@@ -503,7 +506,7 @@ def main():
 
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir = Path(args.output_dir)/slugify(args.project)/now
     output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
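This hunk and the ones that follow are a mechanical switch from Path.joinpath() to pathlib's / operator; both build the same path object, so the behaviour of output_dir and everything derived from it is unchanged. A quick check (the directory names are made up):

from pathlib import Path

base = Path("output/ti")
# Path.__truediv__ ("/") and joinpath() produce identical paths:
assert base / "my-project" / "2023-02-13T17-19-18" == base.joinpath("my-project", "2023-02-13T17-19-18")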
@@ -519,7 +522,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
@@ -570,7 +573,7 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+    checkpoint_output_dir = output_dir/"checkpoints"
 
     trainer = partial(
         train,
@@ -611,11 +614,11 @@ def main():
         return
 
     if len(placeholder_tokens) == 1:
-        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
-        metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png")
+        sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
+        metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
     else:
-        sample_output_dir = output_dir.joinpath("samples")
-        metrics_output_file = output_dir.joinpath(f"lr.png")
+        sample_output_dir = output_dir/"samples"
+        metrics_output_file = output_dir/f"lr.png"
 
     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
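For reference, the run directory laid out by main() after this change, sketched for a single placeholder token (the project slug, timestamp, and token name are illustrative):

from pathlib import Path

output_dir = Path("output/ti") / "my-project" / "2023-02-13T17-19-18"
token = "<my-token>"                                   # illustrative placeholder token
checkpoint_output_dir = output_dir / "checkpoints"     # saved checkpoints
sample_output_dir = output_dir / f"samples_{token}"    # sample images for the token
metrics_output_file = output_dir / f"{token}.png"      # metrics plot ("lr.png" when several tokens train together)
log_file = output_dir / "log.txt"                      # logging.basicConfig target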