diff options
author | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
commit | 94b676d91382267e7429bd68362019868affd9d1 (patch) | |
tree | 513697739ab25217cbfcff630299d02b1f6e98c8 /train_ti.py | |
parent | Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline (diff) | |
download | textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.gz textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.bz2 textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/train_ti.py b/train_ti.py index c79dfa2..171d085 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -143,7 +143,7 @@ def parse_args(): | |||
143 | parser.add_argument( | 143 | parser.add_argument( |
144 | "--num_buckets", | 144 | "--num_buckets", |
145 | type=int, | 145 | type=int, |
146 | default=4, | 146 | default=0, |
147 | help="Number of aspect ratio buckets in either direction.", | 147 | help="Number of aspect ratio buckets in either direction.", |
148 | ) | 148 | ) |
149 | parser.add_argument( | 149 | parser.add_argument( |
@@ -485,6 +485,9 @@ def parse_args(): | |||
485 | 485 | ||
486 | if len(args.placeholder_tokens) != len(args.train_data_template): | 486 | if len(args.placeholder_tokens) != len(args.train_data_template): |
487 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 487 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") |
488 | else: | ||
489 | if isinstance(args.train_data_template, list): | ||
490 | raise ValueError("--train_data_template can't be a list in simultaneous mode") | ||
488 | 491 | ||
489 | if isinstance(args.collection, str): | 492 | if isinstance(args.collection, str): |
490 | args.collection = [args.collection] | 493 | args.collection = [args.collection] |
@@ -503,7 +506,7 @@ def main(): | |||
503 | 506 | ||
504 | global_step_offset = args.global_step | 507 | global_step_offset = args.global_step |
505 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 508 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
506 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) | 509 | output_dir = Path(args.output_dir)/slugify(args.project)/now |
507 | output_dir.mkdir(parents=True, exist_ok=True) | 510 | output_dir.mkdir(parents=True, exist_ok=True) |
508 | 511 | ||
509 | accelerator = Accelerator( | 512 | accelerator = Accelerator( |
@@ -519,7 +522,7 @@ def main(): | |||
519 | elif args.mixed_precision == "bf16": | 522 | elif args.mixed_precision == "bf16": |
520 | weight_dtype = torch.bfloat16 | 523 | weight_dtype = torch.bfloat16 |
521 | 524 | ||
522 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 525 | logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG) |
523 | 526 | ||
524 | if args.seed is None: | 527 | if args.seed is None: |
525 | args.seed = torch.random.seed() >> 32 | 528 | args.seed = torch.random.seed() >> 32 |
@@ -570,7 +573,7 @@ def main(): | |||
570 | else: | 573 | else: |
571 | optimizer_class = torch.optim.AdamW | 574 | optimizer_class = torch.optim.AdamW |
572 | 575 | ||
573 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | 576 | checkpoint_output_dir = output_dir/"checkpoints" |
574 | 577 | ||
575 | trainer = partial( | 578 | trainer = partial( |
576 | train, | 579 | train, |
@@ -611,11 +614,11 @@ def main(): | |||
611 | return | 614 | return |
612 | 615 | ||
613 | if len(placeholder_tokens) == 1: | 616 | if len(placeholder_tokens) == 1: |
614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") | 617 | sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}" |
615 | metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") | 618 | metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png" |
616 | else: | 619 | else: |
617 | sample_output_dir = output_dir.joinpath("samples") | 620 | sample_output_dir = output_dir/"samples" |
618 | metrics_output_file = output_dir.joinpath(f"lr.png") | 621 | metrics_output_file = output_dir/f"lr.png" |
619 | 622 | ||
620 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 623 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
621 | tokenizer=tokenizer, | 624 | tokenizer=tokenizer, |