diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1b8a3d2..7a33bca 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -110,7 +110,7 @@ def parse_args(): | |||
110 | parser.add_argument( | 110 | parser.add_argument( |
111 | "--tag_dropout", | 111 | "--tag_dropout", |
112 | type=float, | 112 | type=float, |
113 | default=0.1, | 113 | default=0, |
114 | help="Tag dropout probability.", | 114 | help="Tag dropout probability.", |
115 | ) | 115 | ) |
116 | parser.add_argument( | 116 | parser.add_argument( |
@@ -131,6 +131,11 @@ def parse_args(): | |||
131 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | 131 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', |
132 | ) | 132 | ) |
133 | parser.add_argument( | 133 | parser.add_argument( |
134 | "--guidance_scale", | ||
135 | type=float, | ||
136 | default=0, | ||
137 | ) | ||
138 | parser.add_argument( | ||
134 | "--num_class_images", | 139 | "--num_class_images", |
135 | type=int, | 140 | type=int, |
136 | default=0, | 141 | default=0, |
@@ -178,7 +183,7 @@ def parse_args(): | |||
178 | parser.add_argument( | 183 | parser.add_argument( |
179 | "--offset_noise_strength", | 184 | "--offset_noise_strength", |
180 | type=float, | 185 | type=float, |
181 | default=0.15, | 186 | default=0, |
182 | help="Perlin offset noise strength.", | 187 | help="Perlin offset noise strength.", |
183 | ) | 188 | ) |
184 | parser.add_argument( | 189 | parser.add_argument( |
@@ -557,8 +562,8 @@ def main(): | |||
557 | vae=vae, | 562 | vae=vae, |
558 | noise_scheduler=noise_scheduler, | 563 | noise_scheduler=noise_scheduler, |
559 | dtype=weight_dtype, | 564 | dtype=weight_dtype, |
560 | with_prior_preservation=args.num_class_images != 0, | 565 | guidance_scale=args.guidance_scale, |
561 | prior_loss_weight=args.prior_loss_weight, | 566 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
562 | no_val=args.valid_set_size == 0, | 567 | no_val=args.valid_set_size == 0, |
563 | ) | 568 | ) |
564 | 569 | ||
@@ -570,6 +575,7 @@ def main(): | |||
570 | batch_size=args.train_batch_size, | 575 | batch_size=args.train_batch_size, |
571 | tokenizer=tokenizer, | 576 | tokenizer=tokenizer, |
572 | class_subdir=args.class_image_dir, | 577 | class_subdir=args.class_image_dir, |
578 | with_guidance=args.guidance_scale != 0, | ||
573 | num_class_images=args.num_class_images, | 579 | num_class_images=args.num_class_images, |
574 | size=args.resolution, | 580 | size=args.resolution, |
575 | num_buckets=args.num_buckets, | 581 | num_buckets=args.num_buckets, |