summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-25 16:34:48 +0100
committerVolpeon <git@volpeon.ink>2023-03-25 16:34:48 +0100
commit6b8a93f46f053668c8023520225a18445d48d8f1 (patch)
tree463c8835a9a90dd9b5586a13e55d6882caa3103a /train_ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.gz
textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.bz2
textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.zip
Update
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/train_ti.py b/train_ti.py
index bbc5524..83ad46d 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -91,6 +91,11 @@ def parse_args():
91 action="store_true", 91 action="store_true",
92 ) 92 )
93 parser.add_argument( 93 parser.add_argument(
94 "--guidance_scale",
95 type=float,
96 default=0,
97 )
98 parser.add_argument(
94 "--num_class_images", 99 "--num_class_images",
95 type=int, 100 type=int,
96 default=0, 101 default=0,
@@ -167,7 +172,7 @@ def parse_args():
167 parser.add_argument( 172 parser.add_argument(
168 "--tag_dropout", 173 "--tag_dropout",
169 type=float, 174 type=float,
170 default=0.1, 175 default=0,
171 help="Tag dropout probability.", 176 help="Tag dropout probability.",
172 ) 177 )
173 parser.add_argument( 178 parser.add_argument(
@@ -190,7 +195,7 @@ def parse_args():
190 parser.add_argument( 195 parser.add_argument(
191 "--offset_noise_strength", 196 "--offset_noise_strength",
192 type=float, 197 type=float,
193 default=0.15, 198 default=0,
194 help="Perlin offset noise strength.", 199 help="Perlin offset noise strength.",
195 ) 200 )
196 parser.add_argument( 201 parser.add_argument(
@@ -651,8 +656,8 @@ def main():
651 noise_scheduler=noise_scheduler, 656 noise_scheduler=noise_scheduler,
652 dtype=weight_dtype, 657 dtype=weight_dtype,
653 seed=args.seed, 658 seed=args.seed,
654 with_prior_preservation=args.num_class_images != 0, 659 guidance_scale=args.guidance_scale,
655 prior_loss_weight=args.prior_loss_weight, 660 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
656 no_val=args.valid_set_size == 0, 661 no_val=args.valid_set_size == 0,
657 strategy=textual_inversion_strategy, 662 strategy=textual_inversion_strategy,
658 num_train_epochs=args.num_train_epochs, 663 num_train_epochs=args.num_train_epochs,
@@ -705,6 +710,7 @@ def main():
705 batch_size=args.train_batch_size, 710 batch_size=args.train_batch_size,
706 tokenizer=tokenizer, 711 tokenizer=tokenizer,
707 class_subdir=args.class_image_dir, 712 class_subdir=args.class_image_dir,
713 with_guidance=args.guidance_scale != 0,
708 num_class_images=args.num_class_images, 714 num_class_images=args.num_class_images,
709 size=args.resolution, 715 size=args.resolution,
710 num_buckets=args.num_buckets, 716 num_buckets=args.num_buckets,