diff options
author | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
commit | 6b8a93f46f053668c8023520225a18445d48d8f1 (patch) | |
tree | 463c8835a9a90dd9b5586a13e55d6882caa3103a /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.gz textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.bz2 textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.zip |
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py index bbc5524..83ad46d 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -91,6 +91,11 @@ def parse_args(): | |||
91 | action="store_true", | 91 | action="store_true", |
92 | ) | 92 | ) |
93 | parser.add_argument( | 93 | parser.add_argument( |
94 | "--guidance_scale", | ||
95 | type=float, | ||
96 | default=0, | ||
97 | ) | ||
98 | parser.add_argument( | ||
94 | "--num_class_images", | 99 | "--num_class_images", |
95 | type=int, | 100 | type=int, |
96 | default=0, | 101 | default=0, |
@@ -167,7 +172,7 @@ def parse_args(): | |||
167 | parser.add_argument( | 172 | parser.add_argument( |
168 | "--tag_dropout", | 173 | "--tag_dropout", |
169 | type=float, | 174 | type=float, |
170 | default=0.1, | 175 | default=0, |
171 | help="Tag dropout probability.", | 176 | help="Tag dropout probability.", |
172 | ) | 177 | ) |
173 | parser.add_argument( | 178 | parser.add_argument( |
@@ -190,7 +195,7 @@ def parse_args(): | |||
190 | parser.add_argument( | 195 | parser.add_argument( |
191 | "--offset_noise_strength", | 196 | "--offset_noise_strength", |
192 | type=float, | 197 | type=float, |
193 | default=0.15, | 198 | default=0, |
194 | help="Perlin offset noise strength.", | 199 | help="Perlin offset noise strength.", |
195 | ) | 200 | ) |
196 | parser.add_argument( | 201 | parser.add_argument( |
@@ -651,8 +656,8 @@ def main(): | |||
651 | noise_scheduler=noise_scheduler, | 656 | noise_scheduler=noise_scheduler, |
652 | dtype=weight_dtype, | 657 | dtype=weight_dtype, |
653 | seed=args.seed, | 658 | seed=args.seed, |
654 | with_prior_preservation=args.num_class_images != 0, | 659 | guidance_scale=args.guidance_scale, |
655 | prior_loss_weight=args.prior_loss_weight, | 660 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
656 | no_val=args.valid_set_size == 0, | 661 | no_val=args.valid_set_size == 0, |
657 | strategy=textual_inversion_strategy, | 662 | strategy=textual_inversion_strategy, |
658 | num_train_epochs=args.num_train_epochs, | 663 | num_train_epochs=args.num_train_epochs, |
@@ -705,6 +710,7 @@ def main(): | |||
705 | batch_size=args.train_batch_size, | 710 | batch_size=args.train_batch_size, |
706 | tokenizer=tokenizer, | 711 | tokenizer=tokenizer, |
707 | class_subdir=args.class_image_dir, | 712 | class_subdir=args.class_image_dir, |
713 | with_guidance=args.guidance_scale != 0, | ||
708 | num_class_images=args.num_class_images, | 714 | num_class_images=args.num_class_images, |
709 | size=args.resolution, | 715 | size=args.resolution, |
710 | num_buckets=args.num_buckets, | 716 | num_buckets=args.num_buckets, |