diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-09 10:19:37 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-09 10:19:37 +0100 |
| commit | b57ca669a150d9313447612fb8c37668f4f2a80d (patch) | |
| tree | b0ebfedc33c26847838850416b96fd2623cf6ba5 /train_ti.py | |
| parent | No cache after all (diff) | |
| download | textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.tar.gz textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.tar.bz2 textual-inversion-diff-b57ca669a150d9313447612fb8c37668f4f2a80d.zip | |
Add --valid_set_repeat
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index 03f52c4..7784d04 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -381,6 +381,12 @@ def parse_args(): | |||
| 381 | help="Number of images in the validation dataset." | 381 | help="Number of images in the validation dataset." |
| 382 | ) | 382 | ) |
| 383 | parser.add_argument( | 383 | parser.add_argument( |
| 384 | "--valid_set_repeat", | ||
| 385 | type=int, | ||
| 386 | default=None, | ||
| 387 | help="Times the images in the validation dataset are repeated." | ||
| 388 | ) | ||
| 389 | parser.add_argument( | ||
| 384 | "--train_batch_size", | 390 | "--train_batch_size", |
| 385 | type=int, | 391 | type=int, |
| 386 | default=1, | 392 | default=1, |
| @@ -399,6 +405,12 @@ def parse_args(): | |||
| 399 | help="The weight of prior preservation loss." | 405 | help="The weight of prior preservation loss." |
| 400 | ) | 406 | ) |
| 401 | parser.add_argument( | 407 | parser.add_argument( |
| 408 | "--max_grad_norm", | ||
| 409 | default=3.0, | ||
| 410 | type=float, | ||
| 411 | help="Max gradient norm." | ||
| 412 | ) | ||
| 413 | parser.add_argument( | ||
| 402 | "--noise_timesteps", | 414 | "--noise_timesteps", |
| 403 | type=int, | 415 | type=int, |
| 404 | default=1000, | 416 | default=1000, |
| @@ -465,6 +477,9 @@ def parse_args(): | |||
| 465 | if isinstance(args.exclude_collections, str): | 477 | if isinstance(args.exclude_collections, str): |
| 466 | args.exclude_collections = [args.exclude_collections] | 478 | args.exclude_collections = [args.exclude_collections] |
| 467 | 479 | ||
| 480 | if args.valid_set_repeat is None: | ||
| 481 | args.valid_set_repeat = args.train_batch_size | ||
| 482 | |||
| 468 | if args.output_dir is None: | 483 | if args.output_dir is None: |
| 469 | raise ValueError("You must specify --output_dir") | 484 | raise ValueError("You must specify --output_dir") |
| 470 | 485 | ||
| @@ -735,6 +750,7 @@ def main(): | |||
| 735 | dropout=args.tag_dropout, | 750 | dropout=args.tag_dropout, |
| 736 | template_key=args.train_data_template, | 751 | template_key=args.train_data_template, |
| 737 | valid_set_size=args.valid_set_size, | 752 | valid_set_size=args.valid_set_size, |
| 753 | valid_set_repeat=args.valid_set_repeat, | ||
| 738 | num_workers=args.dataloader_num_workers, | 754 | num_workers=args.dataloader_num_workers, |
| 739 | seed=args.seed, | 755 | seed=args.seed, |
| 740 | filter=keyword_filter, | 756 | filter=keyword_filter, |
| @@ -961,6 +977,12 @@ def main(): | |||
| 961 | 977 | ||
| 962 | accelerator.backward(loss) | 978 | accelerator.backward(loss) |
| 963 | 979 | ||
| 980 | if accelerator.sync_gradients: | ||
| 981 | accelerator.clip_grad_norm_( | ||
| 982 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
| 983 | args.max_grad_norm | ||
| 984 | ) | ||
| 985 | |||
| 964 | optimizer.step() | 986 | optimizer.step() |
| 965 | if not accelerator.optimizer_step_was_skipped: | 987 | if not accelerator.optimizer_step_was_skipped: |
| 966 | lr_scheduler.step() | 988 | lr_scheduler.step() |
