diff options
| author | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-02-13 17:19:18 +0100 |
| commit | 94b676d91382267e7429bd68362019868affd9d1 (patch) | |
| tree | 513697739ab25217cbfcff630299d02b1f6e98c8 /train_ti.py | |
| parent | Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline (diff) | |
| download | textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.gz textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.tar.bz2 textual-inversion-diff-94b676d91382267e7429bd68362019868affd9d1.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/train_ti.py b/train_ti.py index c79dfa2..171d085 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -143,7 +143,7 @@ def parse_args(): | |||
| 143 | parser.add_argument( | 143 | parser.add_argument( |
| 144 | "--num_buckets", | 144 | "--num_buckets", |
| 145 | type=int, | 145 | type=int, |
| 146 | default=4, | 146 | default=0, |
| 147 | help="Number of aspect ratio buckets in either direction.", | 147 | help="Number of aspect ratio buckets in either direction.", |
| 148 | ) | 148 | ) |
| 149 | parser.add_argument( | 149 | parser.add_argument( |
| @@ -485,6 +485,9 @@ def parse_args(): | |||
| 485 | 485 | ||
| 486 | if len(args.placeholder_tokens) != len(args.train_data_template): | 486 | if len(args.placeholder_tokens) != len(args.train_data_template): |
| 487 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 487 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") |
| 488 | else: | ||
| 489 | if isinstance(args.train_data_template, list): | ||
| 490 | raise ValueError("--train_data_template can't be a list in simultaneous mode") | ||
| 488 | 491 | ||
| 489 | if isinstance(args.collection, str): | 492 | if isinstance(args.collection, str): |
| 490 | args.collection = [args.collection] | 493 | args.collection = [args.collection] |
| @@ -503,7 +506,7 @@ def main(): | |||
| 503 | 506 | ||
| 504 | global_step_offset = args.global_step | 507 | global_step_offset = args.global_step |
| 505 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 508 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 506 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) | 509 | output_dir = Path(args.output_dir)/slugify(args.project)/now |
| 507 | output_dir.mkdir(parents=True, exist_ok=True) | 510 | output_dir.mkdir(parents=True, exist_ok=True) |
| 508 | 511 | ||
| 509 | accelerator = Accelerator( | 512 | accelerator = Accelerator( |
| @@ -519,7 +522,7 @@ def main(): | |||
| 519 | elif args.mixed_precision == "bf16": | 522 | elif args.mixed_precision == "bf16": |
| 520 | weight_dtype = torch.bfloat16 | 523 | weight_dtype = torch.bfloat16 |
| 521 | 524 | ||
| 522 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | 525 | logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG) |
| 523 | 526 | ||
| 524 | if args.seed is None: | 527 | if args.seed is None: |
| 525 | args.seed = torch.random.seed() >> 32 | 528 | args.seed = torch.random.seed() >> 32 |
| @@ -570,7 +573,7 @@ def main(): | |||
| 570 | else: | 573 | else: |
| 571 | optimizer_class = torch.optim.AdamW | 574 | optimizer_class = torch.optim.AdamW |
| 572 | 575 | ||
| 573 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | 576 | checkpoint_output_dir = output_dir/"checkpoints" |
| 574 | 577 | ||
| 575 | trainer = partial( | 578 | trainer = partial( |
| 576 | train, | 579 | train, |
| @@ -611,11 +614,11 @@ def main(): | |||
| 611 | return | 614 | return |
| 612 | 615 | ||
| 613 | if len(placeholder_tokens) == 1: | 616 | if len(placeholder_tokens) == 1: |
| 614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") | 617 | sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}" |
| 615 | metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") | 618 | metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png" |
| 616 | else: | 619 | else: |
| 617 | sample_output_dir = output_dir.joinpath("samples") | 620 | sample_output_dir = output_dir/"samples" |
| 618 | metrics_output_file = output_dir.joinpath(f"lr.png") | 621 | metrics_output_file = output_dir/f"lr.png" |
| 619 | 622 | ||
| 620 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 623 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 621 | tokenizer=tokenizer, | 624 | tokenizer=tokenizer, |
