diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
| commit | 6b8a93f46f053668c8023520225a18445d48d8f1 (patch) | |
| tree | 463c8835a9a90dd9b5586a13e55d6882caa3103a /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.gz textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.bz2 textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.zip | |
Update
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py index bbc5524..83ad46d 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -91,6 +91,11 @@ def parse_args(): | |||
| 91 | action="store_true", | 91 | action="store_true", |
| 92 | ) | 92 | ) |
| 93 | parser.add_argument( | 93 | parser.add_argument( |
| 94 | "--guidance_scale", | ||
| 95 | type=float, | ||
| 96 | default=0, | ||
| 97 | ) | ||
| 98 | parser.add_argument( | ||
| 94 | "--num_class_images", | 99 | "--num_class_images", |
| 95 | type=int, | 100 | type=int, |
| 96 | default=0, | 101 | default=0, |
| @@ -167,7 +172,7 @@ def parse_args(): | |||
| 167 | parser.add_argument( | 172 | parser.add_argument( |
| 168 | "--tag_dropout", | 173 | "--tag_dropout", |
| 169 | type=float, | 174 | type=float, |
| 170 | default=0.1, | 175 | default=0, |
| 171 | help="Tag dropout probability.", | 176 | help="Tag dropout probability.", |
| 172 | ) | 177 | ) |
| 173 | parser.add_argument( | 178 | parser.add_argument( |
| @@ -190,7 +195,7 @@ def parse_args(): | |||
| 190 | parser.add_argument( | 195 | parser.add_argument( |
| 191 | "--offset_noise_strength", | 196 | "--offset_noise_strength", |
| 192 | type=float, | 197 | type=float, |
| 193 | default=0.15, | 198 | default=0, |
| 194 | help="Perlin offset noise strength.", | 199 | help="Perlin offset noise strength.", |
| 195 | ) | 200 | ) |
| 196 | parser.add_argument( | 201 | parser.add_argument( |
| @@ -651,8 +656,8 @@ def main(): | |||
| 651 | noise_scheduler=noise_scheduler, | 656 | noise_scheduler=noise_scheduler, |
| 652 | dtype=weight_dtype, | 657 | dtype=weight_dtype, |
| 653 | seed=args.seed, | 658 | seed=args.seed, |
| 654 | with_prior_preservation=args.num_class_images != 0, | 659 | guidance_scale=args.guidance_scale, |
| 655 | prior_loss_weight=args.prior_loss_weight, | 660 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 656 | no_val=args.valid_set_size == 0, | 661 | no_val=args.valid_set_size == 0, |
| 657 | strategy=textual_inversion_strategy, | 662 | strategy=textual_inversion_strategy, |
| 658 | num_train_epochs=args.num_train_epochs, | 663 | num_train_epochs=args.num_train_epochs, |
| @@ -705,6 +710,7 @@ def main(): | |||
| 705 | batch_size=args.train_batch_size, | 710 | batch_size=args.train_batch_size, |
| 706 | tokenizer=tokenizer, | 711 | tokenizer=tokenizer, |
| 707 | class_subdir=args.class_image_dir, | 712 | class_subdir=args.class_image_dir, |
| 713 | with_guidance=args.guidance_scale != 0, | ||
| 708 | num_class_images=args.num_class_images, | 714 | num_class_images=args.num_class_images, |
| 709 | size=args.resolution, | 715 | size=args.resolution, |
| 710 | num_buckets=args.num_buckets, | 716 | num_buckets=args.num_buckets, |
