diff options
author | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-25 16:34:48 +0100 |
commit | 6b8a93f46f053668c8023520225a18445d48d8f1 (patch) | |
tree | 463c8835a9a90dd9b5586a13e55d6882caa3103a /train_lora.py | |
parent | Update (diff) | |
download | textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.gz textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.tar.bz2 textual-inversion-diff-6b8a93f46f053668c8023520225a18445d48d8f1.zip |
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py index b16a99b..684d0cc 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -88,7 +88,7 @@ def parse_args(): | |||
88 | parser.add_argument( | 88 | parser.add_argument( |
89 | "--num_buckets", | 89 | "--num_buckets", |
90 | type=int, | 90 | type=int, |
91 | default=0, | 91 | default=2, |
92 | help="Number of aspect ratio buckets in either direction.", | 92 | help="Number of aspect ratio buckets in either direction.", |
93 | ) | 93 | ) |
94 | parser.add_argument( | 94 | parser.add_argument( |
@@ -111,7 +111,7 @@ def parse_args(): | |||
111 | parser.add_argument( | 111 | parser.add_argument( |
112 | "--tag_dropout", | 112 | "--tag_dropout", |
113 | type=float, | 113 | type=float, |
114 | default=0.1, | 114 | default=0, |
115 | help="Tag dropout probability.", | 115 | help="Tag dropout probability.", |
116 | ) | 116 | ) |
117 | parser.add_argument( | 117 | parser.add_argument( |
@@ -120,6 +120,11 @@ def parse_args(): | |||
120 | help="Shuffle tags.", | 120 | help="Shuffle tags.", |
121 | ) | 121 | ) |
122 | parser.add_argument( | 122 | parser.add_argument( |
123 | "--guidance_scale", | ||
124 | type=float, | ||
125 | default=0, | ||
126 | ) | ||
127 | parser.add_argument( | ||
123 | "--num_class_images", | 128 | "--num_class_images", |
124 | type=int, | 129 | type=int, |
125 | default=0, | 130 | default=0, |
@@ -167,7 +172,7 @@ def parse_args(): | |||
167 | parser.add_argument( | 172 | parser.add_argument( |
168 | "--offset_noise_strength", | 173 | "--offset_noise_strength", |
169 | type=float, | 174 | type=float, |
170 | default=0.15, | 175 | default=0, |
171 | help="Perlin offset noise strength.", | 176 | help="Perlin offset noise strength.", |
172 | ) | 177 | ) |
173 | parser.add_argument( | 178 | parser.add_argument( |
@@ -589,8 +594,8 @@ def main(): | |||
589 | vae=vae, | 594 | vae=vae, |
590 | noise_scheduler=noise_scheduler, | 595 | noise_scheduler=noise_scheduler, |
591 | dtype=weight_dtype, | 596 | dtype=weight_dtype, |
592 | with_prior_preservation=args.num_class_images != 0, | 597 | guidance_scale=args.guidance_scale, |
593 | prior_loss_weight=args.prior_loss_weight, | 598 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
594 | no_val=args.valid_set_size == 0, | 599 | no_val=args.valid_set_size == 0, |
595 | ) | 600 | ) |
596 | 601 | ||
@@ -602,6 +607,7 @@ def main(): | |||
602 | batch_size=args.train_batch_size, | 607 | batch_size=args.train_batch_size, |
603 | tokenizer=tokenizer, | 608 | tokenizer=tokenizer, |
604 | class_subdir=args.class_image_dir, | 609 | class_subdir=args.class_image_dir, |
610 | with_guidance=args.guidance_scale != 0, | ||
605 | num_class_images=args.num_class_images, | 611 | num_class_images=args.num_class_images, |
606 | size=args.resolution, | 612 | size=args.resolution, |
607 | num_buckets=args.num_buckets, | 613 | num_buckets=args.num_buckets, |