summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1b8a3d2..7a33bca 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -110,7 +110,7 @@ def parse_args():
110 parser.add_argument( 110 parser.add_argument(
111 "--tag_dropout", 111 "--tag_dropout",
112 type=float, 112 type=float,
113 default=0.1, 113 default=0,
114 help="Tag dropout probability.", 114 help="Tag dropout probability.",
115 ) 115 )
116 parser.add_argument( 116 parser.add_argument(
@@ -131,6 +131,11 @@ def parse_args():
131 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', 131 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
132 ) 132 )
133 parser.add_argument( 133 parser.add_argument(
134 "--guidance_scale",
135 type=float,
136 default=0,
137 )
138 parser.add_argument(
134 "--num_class_images", 139 "--num_class_images",
135 type=int, 140 type=int,
136 default=0, 141 default=0,
@@ -178,7 +183,7 @@ def parse_args():
178 parser.add_argument( 183 parser.add_argument(
179 "--offset_noise_strength", 184 "--offset_noise_strength",
180 type=float, 185 type=float,
181 default=0.15, 186 default=0,
182 help="Perlin offset noise strength.", 187 help="Perlin offset noise strength.",
183 ) 188 )
184 parser.add_argument( 189 parser.add_argument(
@@ -557,8 +562,8 @@ def main():
557 vae=vae, 562 vae=vae,
558 noise_scheduler=noise_scheduler, 563 noise_scheduler=noise_scheduler,
559 dtype=weight_dtype, 564 dtype=weight_dtype,
560 with_prior_preservation=args.num_class_images != 0, 565 guidance_scale=args.guidance_scale,
561 prior_loss_weight=args.prior_loss_weight, 566 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
562 no_val=args.valid_set_size == 0, 567 no_val=args.valid_set_size == 0,
563 ) 568 )
564 569
@@ -570,6 +575,7 @@ def main():
570 batch_size=args.train_batch_size, 575 batch_size=args.train_batch_size,
571 tokenizer=tokenizer, 576 tokenizer=tokenizer,
572 class_subdir=args.class_image_dir, 577 class_subdir=args.class_image_dir,
578 with_guidance=args.guidance_scale != 0,
573 num_class_images=args.num_class_images, 579 num_class_images=args.num_class_images,
574 size=args.resolution, 580 size=args.resolution,
575 num_buckets=args.num_buckets, 581 num_buckets=args.num_buckets,