diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/train_lora.py b/train_lora.py index b16a99b..684d0cc 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -88,7 +88,7 @@ def parse_args(): | |||
| 88 | parser.add_argument( | 88 | parser.add_argument( |
| 89 | "--num_buckets", | 89 | "--num_buckets", |
| 90 | type=int, | 90 | type=int, |
| 91 | default=0, | 91 | default=2, |
| 92 | help="Number of aspect ratio buckets in either direction.", | 92 | help="Number of aspect ratio buckets in either direction.", |
| 93 | ) | 93 | ) |
| 94 | parser.add_argument( | 94 | parser.add_argument( |
| @@ -111,7 +111,7 @@ def parse_args(): | |||
| 111 | parser.add_argument( | 111 | parser.add_argument( |
| 112 | "--tag_dropout", | 112 | "--tag_dropout", |
| 113 | type=float, | 113 | type=float, |
| 114 | default=0.1, | 114 | default=0, |
| 115 | help="Tag dropout probability.", | 115 | help="Tag dropout probability.", |
| 116 | ) | 116 | ) |
| 117 | parser.add_argument( | 117 | parser.add_argument( |
| @@ -120,6 +120,11 @@ def parse_args(): | |||
| 120 | help="Shuffle tags.", | 120 | help="Shuffle tags.", |
| 121 | ) | 121 | ) |
| 122 | parser.add_argument( | 122 | parser.add_argument( |
| 123 | "--guidance_scale", | ||
| 124 | type=float, | ||
| 125 | default=0, | ||
| 126 | ) | ||
| 127 | parser.add_argument( | ||
| 123 | "--num_class_images", | 128 | "--num_class_images", |
| 124 | type=int, | 129 | type=int, |
| 125 | default=0, | 130 | default=0, |
| @@ -167,7 +172,7 @@ def parse_args(): | |||
| 167 | parser.add_argument( | 172 | parser.add_argument( |
| 168 | "--offset_noise_strength", | 173 | "--offset_noise_strength", |
| 169 | type=float, | 174 | type=float, |
| 170 | default=0.15, | 175 | default=0, |
| 171 | help="Perlin offset noise strength.", | 176 | help="Perlin offset noise strength.", |
| 172 | ) | 177 | ) |
| 173 | parser.add_argument( | 178 | parser.add_argument( |
| @@ -589,8 +594,8 @@ def main(): | |||
| 589 | vae=vae, | 594 | vae=vae, |
| 590 | noise_scheduler=noise_scheduler, | 595 | noise_scheduler=noise_scheduler, |
| 591 | dtype=weight_dtype, | 596 | dtype=weight_dtype, |
| 592 | with_prior_preservation=args.num_class_images != 0, | 597 | guidance_scale=args.guidance_scale, |
| 593 | prior_loss_weight=args.prior_loss_weight, | 598 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 594 | no_val=args.valid_set_size == 0, | 599 | no_val=args.valid_set_size == 0, |
| 595 | ) | 600 | ) |
| 596 | 601 | ||
| @@ -602,6 +607,7 @@ def main(): | |||
| 602 | batch_size=args.train_batch_size, | 607 | batch_size=args.train_batch_size, |
| 603 | tokenizer=tokenizer, | 608 | tokenizer=tokenizer, |
| 604 | class_subdir=args.class_image_dir, | 609 | class_subdir=args.class_image_dir, |
| 610 | with_guidance=args.guidance_scale != 0, | ||
| 605 | num_class_images=args.num_class_images, | 611 | num_class_images=args.num_class_images, |
| 606 | size=args.resolution, | 612 | size=args.resolution, |
| 607 | num_buckets=args.num_buckets, | 613 | num_buckets=args.num_buckets, |
