diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 1b8a3d2..7a33bca 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -110,7 +110,7 @@ def parse_args(): | |||
| 110 | parser.add_argument( | 110 | parser.add_argument( |
| 111 | "--tag_dropout", | 111 | "--tag_dropout", |
| 112 | type=float, | 112 | type=float, |
| 113 | default=0.1, | 113 | default=0, |
| 114 | help="Tag dropout probability.", | 114 | help="Tag dropout probability.", |
| 115 | ) | 115 | ) |
| 116 | parser.add_argument( | 116 | parser.add_argument( |
| @@ -131,6 +131,11 @@ def parse_args(): | |||
| 131 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | 131 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', |
| 132 | ) | 132 | ) |
| 133 | parser.add_argument( | 133 | parser.add_argument( |
| 134 | "--guidance_scale", | ||
| 135 | type=float, | ||
| 136 | default=0, | ||
| 137 | ) | ||
| 138 | parser.add_argument( | ||
| 134 | "--num_class_images", | 139 | "--num_class_images", |
| 135 | type=int, | 140 | type=int, |
| 136 | default=0, | 141 | default=0, |
| @@ -178,7 +183,7 @@ def parse_args(): | |||
| 178 | parser.add_argument( | 183 | parser.add_argument( |
| 179 | "--offset_noise_strength", | 184 | "--offset_noise_strength", |
| 180 | type=float, | 185 | type=float, |
| 181 | default=0.15, | 186 | default=0, |
| 182 | help="Perlin offset noise strength.", | 187 | help="Perlin offset noise strength.", |
| 183 | ) | 188 | ) |
| 184 | parser.add_argument( | 189 | parser.add_argument( |
| @@ -557,8 +562,8 @@ def main(): | |||
| 557 | vae=vae, | 562 | vae=vae, |
| 558 | noise_scheduler=noise_scheduler, | 563 | noise_scheduler=noise_scheduler, |
| 559 | dtype=weight_dtype, | 564 | dtype=weight_dtype, |
| 560 | with_prior_preservation=args.num_class_images != 0, | 565 | guidance_scale=args.guidance_scale, |
| 561 | prior_loss_weight=args.prior_loss_weight, | 566 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 562 | no_val=args.valid_set_size == 0, | 567 | no_val=args.valid_set_size == 0, |
| 563 | ) | 568 | ) |
| 564 | 569 | ||
| @@ -570,6 +575,7 @@ def main(): | |||
| 570 | batch_size=args.train_batch_size, | 575 | batch_size=args.train_batch_size, |
| 571 | tokenizer=tokenizer, | 576 | tokenizer=tokenizer, |
| 572 | class_subdir=args.class_image_dir, | 577 | class_subdir=args.class_image_dir, |
| 578 | with_guidance=args.guidance_scale != 0, | ||
| 573 | num_class_images=args.num_class_images, | 579 | num_class_images=args.num_class_images, |
| 574 | size=args.resolution, | 580 | size=args.resolution, |
| 575 | num_buckets=args.num_buckets, | 581 | num_buckets=args.num_buckets, |
