summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py
index b16a99b..684d0cc 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -88,7 +88,7 @@ def parse_args():
88 parser.add_argument( 88 parser.add_argument(
89 "--num_buckets", 89 "--num_buckets",
90 type=int, 90 type=int,
91 default=0, 91 default=2,
92 help="Number of aspect ratio buckets in either direction.", 92 help="Number of aspect ratio buckets in either direction.",
93 ) 93 )
94 parser.add_argument( 94 parser.add_argument(
@@ -111,7 +111,7 @@ def parse_args():
111 parser.add_argument( 111 parser.add_argument(
112 "--tag_dropout", 112 "--tag_dropout",
113 type=float, 113 type=float,
114 default=0.1, 114 default=0,
115 help="Tag dropout probability.", 115 help="Tag dropout probability.",
116 ) 116 )
117 parser.add_argument( 117 parser.add_argument(
@@ -120,6 +120,11 @@ def parse_args():
120 help="Shuffle tags.", 120 help="Shuffle tags.",
121 ) 121 )
122 parser.add_argument( 122 parser.add_argument(
123 "--guidance_scale",
124 type=float,
125 default=0,
126 )
127 parser.add_argument(
123 "--num_class_images", 128 "--num_class_images",
124 type=int, 129 type=int,
125 default=0, 130 default=0,
@@ -167,7 +172,7 @@ def parse_args():
167 parser.add_argument( 172 parser.add_argument(
168 "--offset_noise_strength", 173 "--offset_noise_strength",
169 type=float, 174 type=float,
170 default=0.15, 175 default=0,
171 help="Perlin offset noise strength.", 176 help="Perlin offset noise strength.",
172 ) 177 )
173 parser.add_argument( 178 parser.add_argument(
@@ -589,8 +594,8 @@ def main():
589 vae=vae, 594 vae=vae,
590 noise_scheduler=noise_scheduler, 595 noise_scheduler=noise_scheduler,
591 dtype=weight_dtype, 596 dtype=weight_dtype,
592 with_prior_preservation=args.num_class_images != 0, 597 guidance_scale=args.guidance_scale,
593 prior_loss_weight=args.prior_loss_weight, 598 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
594 no_val=args.valid_set_size == 0, 599 no_val=args.valid_set_size == 0,
595 ) 600 )
596 601
@@ -602,6 +607,7 @@ def main():
602 batch_size=args.train_batch_size, 607 batch_size=args.train_batch_size,
603 tokenizer=tokenizer, 608 tokenizer=tokenizer,
604 class_subdir=args.class_image_dir, 609 class_subdir=args.class_image_dir,
610 with_guidance=args.guidance_scale != 0,
605 num_class_images=args.num_class_images, 611 num_class_images=args.num_class_images,
606 size=args.resolution, 612 size=args.resolution,
607 num_buckets=args.num_buckets, 613 num_buckets=args.num_buckets,