diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-24 16:26:22 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-24 16:26:22 +0200 |
| commit | 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 (patch) | |
| tree | 6c1f2243475778bb5e9e1725bf3969a5442393d8 /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.tar.gz textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.tar.bz2 textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.zip | |
Fixes
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index 8c63493..89f4113 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -160,6 +160,12 @@ def parse_args(): | |||
| 160 | help="A collection to filter the dataset.", | 160 | help="A collection to filter the dataset.", |
| 161 | ) | 161 | ) |
| 162 | parser.add_argument( | 162 | parser.add_argument( |
| 163 | "--validation_prompts", | ||
| 164 | type=str, | ||
| 165 | nargs="*", | ||
| 166 | help="Prompts for additional validation images", | ||
| 167 | ) | ||
| 168 | parser.add_argument( | ||
| 163 | "--seed", type=int, default=None, help="A seed for reproducible training." | 169 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 164 | ) | 170 | ) |
| 165 | parser.add_argument( | 171 | parser.add_argument( |
| @@ -456,7 +462,7 @@ def parse_args(): | |||
| 456 | parser.add_argument( | 462 | parser.add_argument( |
| 457 | "--sample_steps", | 463 | "--sample_steps", |
| 458 | type=int, | 464 | type=int, |
| 459 | default=10, | 465 | default=15, |
| 460 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 466 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
| 461 | ) | 467 | ) |
| 462 | parser.add_argument( | 468 | parser.add_argument( |
| @@ -852,11 +858,6 @@ def main(): | |||
| 852 | sample_image_size=args.sample_image_size, | 858 | sample_image_size=args.sample_image_size, |
| 853 | ) | 859 | ) |
| 854 | 860 | ||
| 855 | optimizer = create_optimizer( | ||
| 856 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
| 857 | lr=args.learning_rate, | ||
| 858 | ) | ||
| 859 | |||
| 860 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 861 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
| 861 | data_npgenerator = np.random.default_rng(args.seed) | 862 | data_npgenerator = np.random.default_rng(args.seed) |
| 862 | 863 | ||
| @@ -957,6 +958,11 @@ def main(): | |||
| 957 | avg_loss_val = AverageMeter() | 958 | avg_loss_val = AverageMeter() |
| 958 | avg_acc_val = AverageMeter() | 959 | avg_acc_val = AverageMeter() |
| 959 | 960 | ||
| 961 | optimizer = create_optimizer( | ||
| 962 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
| 963 | lr=args.learning_rate, | ||
| 964 | ) | ||
| 965 | |||
| 960 | while True: | 966 | while True: |
| 961 | if len(auto_cycles) != 0: | 967 | if len(auto_cycles) != 0: |
| 962 | response = auto_cycles.pop(0) | 968 | response = auto_cycles.pop(0) |
