diff options
| -rw-r--r-- | data/csv.py | 22 | ||||
| -rw-r--r-- | train_dreambooth.py | 27 | ||||
| -rw-r--r-- | train_ti.py | 16 |
3 files changed, 61 insertions, 4 deletions
diff --git a/data/csv.py b/data/csv.py index 7527b7d..55a1988 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -44,18 +44,25 @@ def generate_buckets( | |||
| 44 | items: list[str], | 44 | items: list[str], |
| 45 | base_size: int, | 45 | base_size: int, |
| 46 | step_size: int = 64, | 46 | step_size: int = 64, |
| 47 | max_pixels: Optional[int] = None, | ||
| 47 | num_buckets: int = 4, | 48 | num_buckets: int = 4, |
| 48 | progressive_buckets: bool = False, | 49 | progressive_buckets: bool = False, |
| 49 | return_tensor: bool = True | 50 | return_tensor: bool = True |
| 50 | ): | 51 | ): |
| 52 | if max_pixels is None: | ||
| 53 | max_pixels = (base_size + step_size) ** 2 | ||
| 54 | |||
| 55 | max_pixels = max(max_pixels, base_size * base_size) | ||
| 56 | |||
| 51 | bucket_items: list[int] = [] | 57 | bucket_items: list[int] = [] |
| 52 | bucket_assignments: list[int] = [] | 58 | bucket_assignments: list[int] = [] |
| 53 | buckets = [1.0] | 59 | buckets = [1.0] |
| 54 | 60 | ||
| 55 | for i in range(1, num_buckets + 1): | 61 | for i in range(1, num_buckets + 1): |
| 56 | s = base_size + i * step_size | 62 | long_side = base_size + i * step_size |
| 57 | buckets.append(s / base_size) | 63 | short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) |
| 58 | buckets.append(base_size / s) | 64 | buckets.append(long_side / short_side) |
| 65 | buckets.append(short_side / long_side) | ||
| 59 | 66 | ||
| 60 | buckets = torch.tensor(buckets) | 67 | buckets = torch.tensor(buckets) |
| 61 | bucket_indices = torch.arange(len(buckets)) | 68 | bucket_indices = torch.arange(len(buckets)) |
| @@ -110,6 +117,8 @@ class VlpnDataModule(): | |||
| 110 | num_class_images: int = 1, | 117 | num_class_images: int = 1, |
| 111 | size: int = 768, | 118 | size: int = 768, |
| 112 | num_buckets: int = 0, | 119 | num_buckets: int = 0, |
| 120 | bucket_step_size: int = 64, | ||
| 121 | max_pixels_per_bucket: Optional[int] = None, | ||
| 113 | progressive_buckets: bool = False, | 122 | progressive_buckets: bool = False, |
| 114 | dropout: float = 0, | 123 | dropout: float = 0, |
| 115 | interpolation: str = "bicubic", | 124 | interpolation: str = "bicubic", |
| @@ -135,6 +144,8 @@ class VlpnDataModule(): | |||
| 135 | self.prompt_processor = prompt_processor | 144 | self.prompt_processor = prompt_processor |
| 136 | self.size = size | 145 | self.size = size |
| 137 | self.num_buckets = num_buckets | 146 | self.num_buckets = num_buckets |
| 147 | self.bucket_step_size = bucket_step_size | ||
| 148 | self.max_pixels_per_bucket = max_pixels_per_bucket | ||
| 138 | self.progressive_buckets = progressive_buckets | 149 | self.progressive_buckets = progressive_buckets |
| 139 | self.dropout = dropout | 150 | self.dropout = dropout |
| 140 | self.template_key = template_key | 151 | self.template_key = template_key |
| @@ -223,6 +234,7 @@ class VlpnDataModule(): | |||
| 223 | train_dataset = VlpnDataset( | 234 | train_dataset = VlpnDataset( |
| 224 | self.data_train, self.prompt_processor, | 235 | self.data_train, self.prompt_processor, |
| 225 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 236 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
| 237 | bucket_step_size=self.bucket_step_size, max_pixels_per_bucket=self.max_pixels_per_bucket, | ||
| 226 | batch_size=self.batch_size, generator=generator, | 238 | batch_size=self.batch_size, generator=generator, |
| 227 | size=self.size, interpolation=self.interpolation, | 239 | size=self.size, interpolation=self.interpolation, |
| 228 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 240 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, |
| @@ -251,6 +263,8 @@ class VlpnDataset(IterableDataset): | |||
| 251 | items: list[VlpnDataItem], | 263 | items: list[VlpnDataItem], |
| 252 | prompt_processor: PromptProcessor, | 264 | prompt_processor: PromptProcessor, |
| 253 | num_buckets: int = 1, | 265 | num_buckets: int = 1, |
| 266 | bucket_step_size: int = 64, | ||
| 267 | max_pixels_per_bucket: Optional[int] = None, | ||
| 254 | progressive_buckets: bool = False, | 268 | progressive_buckets: bool = False, |
| 255 | batch_size: int = 1, | 269 | batch_size: int = 1, |
| 256 | num_class_images: int = 0, | 270 | num_class_images: int = 0, |
| @@ -274,7 +288,9 @@ class VlpnDataset(IterableDataset): | |||
| 274 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( | 288 | self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( |
| 275 | [item.instance_image_path for item in items], | 289 | [item.instance_image_path for item in items], |
| 276 | base_size=size, | 290 | base_size=size, |
| 291 | step_size=bucket_step_size, | ||
| 277 | num_buckets=num_buckets, | 292 | num_buckets=num_buckets, |
| 293 | max_pixels=max_pixels_per_bucket, | ||
| 278 | progressive_buckets=progressive_buckets, | 294 | progressive_buckets=progressive_buckets, |
| 279 | ) | 295 | ) |
| 280 | 296 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 79eede6..d396249 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -104,6 +104,29 @@ def parse_args(): | |||
| 104 | help="Number of epochs the text encoder will be trained." | 104 | help="Number of epochs the text encoder will be trained." |
| 105 | ) | 105 | ) |
| 106 | parser.add_argument( | 106 | parser.add_argument( |
| 107 | "--num_buckets", | ||
| 108 | type=int, | ||
| 109 | default=4, | ||
| 110 | help="Number of aspect ratio buckets in either direction.", | ||
| 111 | ) | ||
| 112 | parser.add_argument( | ||
| 113 | "--progressive_buckets", | ||
| 114 | action="store_true", | ||
| 115 | help="Include images in smaller buckets as well.", | ||
| 116 | ) | ||
| 117 | parser.add_argument( | ||
| 118 | "--bucket_step_size", | ||
| 119 | type=int, | ||
| 120 | default=64, | ||
| 121 | help="Step size between buckets.", | ||
| 122 | ) | ||
| 123 | parser.add_argument( | ||
| 124 | "--bucket_max_pixels", | ||
| 125 | type=int, | ||
| 126 | default=None, | ||
| 127 | help="Maximum pixels per bucket.", | ||
| 128 | ) | ||
| 129 | parser.add_argument( | ||
| 107 | "--tag_dropout", | 130 | "--tag_dropout", |
| 108 | type=float, | 131 | type=float, |
| 109 | default=0.1, | 132 | default=0.1, |
| @@ -734,6 +757,10 @@ def main(): | |||
| 734 | class_subdir=args.class_image_dir, | 757 | class_subdir=args.class_image_dir, |
| 735 | num_class_images=args.num_class_images, | 758 | num_class_images=args.num_class_images, |
| 736 | size=args.resolution, | 759 | size=args.resolution, |
| 760 | num_buckets=args.num_buckets, | ||
| 761 | progressive_buckets=args.progressive_buckets, | ||
| 762 | bucket_step_size=args.bucket_step_size, | ||
| 763 | bucket_max_pixels=args.bucket_max_pixels, | ||
| 737 | dropout=args.tag_dropout, | 764 | dropout=args.tag_dropout, |
| 738 | template_key=args.train_data_template, | 765 | template_key=args.train_data_template, |
| 739 | valid_set_size=args.valid_set_size, | 766 | valid_set_size=args.valid_set_size, |
diff --git a/train_ti.py b/train_ti.py index 323ef10..eb0b8b6 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -143,7 +143,7 @@ def parse_args(): | |||
| 143 | "--num_buckets", | 143 | "--num_buckets", |
| 144 | type=int, | 144 | type=int, |
| 145 | default=4, | 145 | default=4, |
| 146 | help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).", | 146 | help="Number of aspect ratio buckets in either direction.", |
| 147 | ) | 147 | ) |
| 148 | parser.add_argument( | 148 | parser.add_argument( |
| 149 | "--progressive_buckets", | 149 | "--progressive_buckets", |
| @@ -151,6 +151,18 @@ def parse_args(): | |||
| 151 | help="Include images in smaller buckets as well.", | 151 | help="Include images in smaller buckets as well.", |
| 152 | ) | 152 | ) |
| 153 | parser.add_argument( | 153 | parser.add_argument( |
| 154 | "--bucket_step_size", | ||
| 155 | type=int, | ||
| 156 | default=64, | ||
| 157 | help="Step size between buckets.", | ||
| 158 | ) | ||
| 159 | parser.add_argument( | ||
| 160 | "--bucket_max_pixels", | ||
| 161 | type=int, | ||
| 162 | default=None, | ||
| 163 | help="Maximum pixels per bucket.", | ||
| 164 | ) | ||
| 165 | parser.add_argument( | ||
| 154 | "--tag_dropout", | 166 | "--tag_dropout", |
| 155 | type=float, | 167 | type=float, |
| 156 | default=0.1, | 168 | default=0.1, |
| @@ -718,6 +730,8 @@ def main(): | |||
| 718 | size=args.resolution, | 730 | size=args.resolution, |
| 719 | num_buckets=args.num_buckets, | 731 | num_buckets=args.num_buckets, |
| 720 | progressive_buckets=args.progressive_buckets, | 732 | progressive_buckets=args.progressive_buckets, |
| 733 | bucket_step_size=args.bucket_step_size, | ||
| 734 | bucket_max_pixels=args.bucket_max_pixels, | ||
| 721 | dropout=args.tag_dropout, | 735 | dropout=args.tag_dropout, |
| 722 | template_key=args.train_data_template, | 736 | template_key=args.train_data_template, |
| 723 | valid_set_size=args.valid_set_size, | 737 | valid_set_size=args.valid_set_size, |
