diff options
| -rw-r--r-- | data/csv.py | 39 | ||||
| -rw-r--r-- | train_dreambooth.py | 71 | ||||
| -rw-r--r-- | train_ti.py | 17 | ||||
| -rw-r--r-- | training/functional.py | 5 | ||||
| -rw-r--r-- | training/optimization.py | 10 | ||||
| -rw-r--r-- | training/strategy/ti.py | 2 |
6 files changed, 101 insertions, 43 deletions
diff --git a/data/csv.py b/data/csv.py index dec66d7..85b98f8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -174,7 +174,8 @@ class VlpnDataModule(): | |||
| 174 | interpolation: str = "bicubic", | 174 | interpolation: str = "bicubic", |
| 175 | template_key: str = "template", | 175 | template_key: str = "template", |
| 176 | valid_set_size: Optional[int] = None, | 176 | valid_set_size: Optional[int] = None, |
| 177 | valid_set_repeat: int = 1, | 177 | train_set_pad: Optional[int] = None, |
| 178 | valid_set_pad: Optional[int] = None, | ||
| 178 | seed: Optional[int] = None, | 179 | seed: Optional[int] = None, |
| 179 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 180 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
| 180 | dtype: torch.dtype = torch.float32, | 181 | dtype: torch.dtype = torch.float32, |
| @@ -202,7 +203,8 @@ class VlpnDataModule(): | |||
| 202 | self.template_key = template_key | 203 | self.template_key = template_key |
| 203 | self.interpolation = interpolation | 204 | self.interpolation = interpolation |
| 204 | self.valid_set_size = valid_set_size | 205 | self.valid_set_size = valid_set_size |
| 205 | self.valid_set_repeat = valid_set_repeat | 206 | self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size |
| 207 | self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size | ||
| 206 | self.seed = seed | 208 | self.seed = seed |
| 207 | self.filter = filter | 209 | self.filter = filter |
| 208 | self.batch_size = batch_size | 210 | self.batch_size = batch_size |
| @@ -267,9 +269,6 @@ class VlpnDataModule(): | |||
| 267 | items = self.prepare_items(template, expansions, items) | 269 | items = self.prepare_items(template, expansions, items) |
| 268 | items = self.filter_items(items) | 270 | items = self.filter_items(items) |
| 269 | 271 | ||
| 270 | if (len(items) < self.batch_size): | ||
| 271 | items = (items * self.batch_size)[:self.batch_size] | ||
| 272 | |||
| 273 | num_images = len(items) | 272 | num_images = len(items) |
| 274 | 273 | ||
| 275 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 | 274 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |
| @@ -283,14 +282,17 @@ class VlpnDataModule(): | |||
| 283 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
| 284 | 283 | ||
| 285 | if valid_set_size == 0: | 284 | if valid_set_size == 0: |
| 286 | data_train, data_val = items, [] | 285 | data_train, data_val = items, items[:1] |
| 287 | else: | 286 | else: |
| 288 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) |
| 289 | 288 | ||
| 290 | self.data_train = self.pad_items(data_train, self.num_class_images) | 289 | data_train = self.pad_items(data_train, self.num_class_images) |
| 290 | |||
| 291 | if len(data_train) < self.train_set_pad: | ||
| 292 | data_train *= math.ceil(self.train_set_pad / len(data_train)) | ||
| 291 | 293 | ||
| 292 | train_dataset = VlpnDataset( | 294 | self.train_dataset = VlpnDataset( |
| 293 | self.data_train, self.tokenizer, | 295 | data_train, self.tokenizer, |
| 294 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 296 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, |
| 295 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 297 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
| 296 | batch_size=self.batch_size, generator=generator, | 298 | batch_size=self.batch_size, generator=generator, |
| @@ -299,24 +301,26 @@ class VlpnDataModule(): | |||
| 299 | ) | 301 | ) |
| 300 | 302 | ||
| 301 | self.train_dataloader = DataLoader( | 303 | self.train_dataloader = DataLoader( |
| 302 | train_dataset, | 304 | self.train_dataset, |
| 303 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 305 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 304 | ) | 306 | ) |
| 305 | 307 | ||
| 306 | if valid_set_size != 0: | 308 | if len(data_val) != 0: |
| 307 | self.data_val = self.pad_items(data_val) | 309 | data_val = self.pad_items(data_val) |
| 310 | |||
| 311 | if len(data_val) < self.valid_set_pad: | ||
| 312 | data_val *= math.ceil(self.valid_set_pad / len(data_val)) | ||
| 308 | 313 | ||
| 309 | val_dataset = VlpnDataset( | 314 | self.val_dataset = VlpnDataset( |
| 310 | self.data_val, self.tokenizer, | 315 | data_val, self.tokenizer, |
| 311 | num_buckets=self.num_buckets, progressive_buckets=True, | 316 | num_buckets=self.num_buckets, progressive_buckets=True, |
| 312 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 317 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, |
| 313 | repeat=self.valid_set_repeat, | ||
| 314 | batch_size=self.batch_size, generator=generator, | 318 | batch_size=self.batch_size, generator=generator, |
| 315 | size=self.size, interpolation=self.interpolation, | 319 | size=self.size, interpolation=self.interpolation, |
| 316 | ) | 320 | ) |
| 317 | 321 | ||
| 318 | self.val_dataloader = DataLoader( | 322 | self.val_dataloader = DataLoader( |
| 319 | val_dataset, | 323 | self.val_dataset, |
| 320 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 324 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 321 | ) | 325 | ) |
| 322 | else: | 326 | else: |
| @@ -332,7 +336,6 @@ class VlpnDataset(IterableDataset): | |||
| 332 | bucket_step_size: int = 64, | 336 | bucket_step_size: int = 64, |
| 333 | bucket_max_pixels: Optional[int] = None, | 337 | bucket_max_pixels: Optional[int] = None, |
| 334 | progressive_buckets: bool = False, | 338 | progressive_buckets: bool = False, |
| 335 | repeat: int = 1, | ||
| 336 | batch_size: int = 1, | 339 | batch_size: int = 1, |
| 337 | num_class_images: int = 0, | 340 | num_class_images: int = 0, |
| 338 | size: int = 768, | 341 | size: int = 768, |
| @@ -341,7 +344,7 @@ class VlpnDataset(IterableDataset): | |||
| 341 | interpolation: str = "bicubic", | 344 | interpolation: str = "bicubic", |
| 342 | generator: Optional[torch.Generator] = None, | 345 | generator: Optional[torch.Generator] = None, |
| 343 | ): | 346 | ): |
| 344 | self.items = items * repeat | 347 | self.items = items |
| 345 | self.batch_size = batch_size | 348 | self.batch_size = batch_size |
| 346 | 349 | ||
| 347 | self.tokenizer = tokenizer | 350 | self.tokenizer = tokenizer |
diff --git a/train_dreambooth.py b/train_dreambooth.py index a9fbbbd..1dc41b1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -55,6 +55,18 @@ def parse_args(): | |||
| 55 | default="template", | 55 | default="template", |
| 56 | ) | 56 | ) |
| 57 | parser.add_argument( | 57 | parser.add_argument( |
| 58 | "--train_set_pad", | ||
| 59 | type=int, | ||
| 60 | default=None, | ||
| 61 | help="The number to fill train dataset items up to." | ||
| 62 | ) | ||
| 63 | parser.add_argument( | ||
| 64 | "--valid_set_pad", | ||
| 65 | type=int, | ||
| 66 | default=None, | ||
| 67 | help="The number to fill validation dataset items up to." | ||
| 68 | ) | ||
| 69 | parser.add_argument( | ||
| 58 | "--project", | 70 | "--project", |
| 59 | type=str, | 71 | type=str, |
| 60 | default=None, | 72 | default=None, |
| @@ -188,11 +200,23 @@ def parse_args(): | |||
| 188 | default=100 | 200 | default=100 |
| 189 | ) | 201 | ) |
| 190 | parser.add_argument( | 202 | parser.add_argument( |
| 203 | "--ti_data_template", | ||
| 204 | type=str, | ||
| 205 | nargs='*', | ||
| 206 | default=[], | ||
| 207 | ) | ||
| 208 | parser.add_argument( | ||
| 191 | "--ti_num_train_epochs", | 209 | "--ti_num_train_epochs", |
| 192 | type=int, | 210 | type=int, |
| 193 | default=10 | 211 | default=10 |
| 194 | ) | 212 | ) |
| 195 | parser.add_argument( | 213 | parser.add_argument( |
| 214 | "--ti_batch_size", | ||
| 215 | type=int, | ||
| 216 | default=1, | ||
| 217 | help="Batch size (per device) for the training dataloader." | ||
| 218 | ) | ||
| 219 | parser.add_argument( | ||
| 196 | "--max_train_steps", | 220 | "--max_train_steps", |
| 197 | type=int, | 221 | type=int, |
| 198 | default=None, | 222 | default=None, |
| @@ -458,6 +482,12 @@ def parse_args(): | |||
| 458 | if len(args.placeholder_tokens) != len(args.num_vectors): | 482 | if len(args.placeholder_tokens) != len(args.num_vectors): |
| 459 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 483 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
| 460 | 484 | ||
| 485 | if isinstance(args.ti_data_template, str): | ||
| 486 | args.ti_data_template = [args.ti_data_template] | ||
| 487 | |||
| 488 | if len(args.ti_data_template) == 0: | ||
| 489 | raise ValueError("You must specify --ti_data_template") | ||
| 490 | |||
| 461 | if isinstance(args.collection, str): | 491 | if isinstance(args.collection, str): |
| 462 | args.collection = [args.collection] | 492 | args.collection = [args.collection] |
| 463 | 493 | ||
| @@ -491,6 +521,8 @@ def main(): | |||
| 491 | 521 | ||
| 492 | set_seed(args.seed) | 522 | set_seed(args.seed) |
| 493 | 523 | ||
| 524 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
| 525 | |||
| 494 | save_args(output_dir, args) | 526 | save_args(output_dir, args) |
| 495 | 527 | ||
| 496 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 528 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| @@ -512,6 +544,8 @@ def main(): | |||
| 512 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 544 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 513 | raise ValueError("--embeddings_dir must point to an existing directory") | 545 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 514 | 546 | ||
| 547 | embeddings.persist() | ||
| 548 | |||
| 515 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 549 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 516 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 550 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 517 | 551 | ||
| @@ -545,7 +579,6 @@ def main(): | |||
| 545 | vae=vae, | 579 | vae=vae, |
| 546 | noise_scheduler=noise_scheduler, | 580 | noise_scheduler=noise_scheduler, |
| 547 | dtype=weight_dtype, | 581 | dtype=weight_dtype, |
| 548 | seed=args.seed, | ||
| 549 | with_prior_preservation=args.num_class_images != 0, | 582 | with_prior_preservation=args.num_class_images != 0, |
| 550 | prior_loss_weight=args.prior_loss_weight, | 583 | prior_loss_weight=args.prior_loss_weight, |
| 551 | ) | 584 | ) |
| @@ -557,13 +590,17 @@ def main(): | |||
| 557 | cur_dir = output_dir.joinpath("1-ti") | 590 | cur_dir = output_dir.joinpath("1-ti") |
| 558 | cur_dir.mkdir(parents=True, exist_ok=True) | 591 | cur_dir.mkdir(parents=True, exist_ok=True) |
| 559 | 592 | ||
| 560 | for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors): | 593 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
| 561 | print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})") | 594 | range(len(args.placeholder_tokens)), |
| 562 | 595 | args.placeholder_tokens, | |
| 596 | args.initializer_tokens, | ||
| 597 | args.num_vectors, | ||
| 598 | args.ti_data_template | ||
| 599 | ): | ||
| 563 | cur_subdir = cur_dir.joinpath(placeholder_token) | 600 | cur_subdir = cur_dir.joinpath(placeholder_token) |
| 564 | cur_subdir.mkdir(parents=True, exist_ok=True) | 601 | cur_subdir.mkdir(parents=True, exist_ok=True) |
| 565 | 602 | ||
| 566 | placeholder_token_ids, _ = add_placeholder_tokens( | 603 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 567 | tokenizer=tokenizer, | 604 | tokenizer=tokenizer, |
| 568 | embeddings=embeddings, | 605 | embeddings=embeddings, |
| 569 | placeholder_tokens=[placeholder_token], | 606 | placeholder_tokens=[placeholder_token], |
| @@ -571,17 +608,23 @@ def main(): | |||
| 571 | num_vectors=[num_vectors] | 608 | num_vectors=[num_vectors] |
| 572 | ) | 609 | ) |
| 573 | 610 | ||
| 611 | print( | ||
| 612 | f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") | ||
| 613 | |||
| 614 | args.seed = seed_generator.seed() | ||
| 615 | |||
| 574 | datamodule = VlpnDataModule( | 616 | datamodule = VlpnDataModule( |
| 575 | data_file=args.train_data_file, | 617 | data_file=args.train_data_file, |
| 576 | batch_size=args.train_batch_size, | 618 | batch_size=args.ti_batch_size, |
| 577 | tokenizer=tokenizer, | 619 | tokenizer=tokenizer, |
| 578 | class_subdir=args.class_image_dir, | 620 | class_subdir=args.class_image_dir, |
| 579 | num_class_images=args.num_class_images, | 621 | num_class_images=args.num_class_images, |
| 580 | size=args.resolution, | 622 | size=args.resolution, |
| 581 | shuffle=not args.no_tag_shuffle, | 623 | shuffle=not args.no_tag_shuffle, |
| 582 | template_key=args.train_data_template, | 624 | template_key=data_template, |
| 583 | valid_set_size=1, | 625 | valid_set_size=1, |
| 584 | valid_set_repeat=args.valid_set_repeat, | 626 | train_set_pad=args.train_set_pad, |
| 627 | valid_set_pad=args.valid_set_pad, | ||
| 585 | seed=args.seed, | 628 | seed=args.seed, |
| 586 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), | 629 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), |
| 587 | dtype=weight_dtype | 630 | dtype=weight_dtype |
| @@ -591,7 +634,9 @@ def main(): | |||
| 591 | optimizer = optimizer_class( | 634 | optimizer = optimizer_class( |
| 592 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 635 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), |
| 593 | lr=args.ti_learning_rate, | 636 | lr=args.ti_learning_rate, |
| 637 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 594 | weight_decay=0.0, | 638 | weight_decay=0.0, |
| 639 | eps=args.adam_epsilon, | ||
| 595 | ) | 640 | ) |
| 596 | 641 | ||
| 597 | lr_scheduler = get_scheduler( | 642 | lr_scheduler = get_scheduler( |
| @@ -600,7 +645,6 @@ def main(): | |||
| 600 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 645 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
| 601 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 646 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 602 | train_epochs=args.ti_num_train_epochs, | 647 | train_epochs=args.ti_num_train_epochs, |
| 603 | warmup_epochs=args.ti_num_train_epochs // 4, | ||
| 604 | ) | 648 | ) |
| 605 | 649 | ||
| 606 | trainer( | 650 | trainer( |
| @@ -608,10 +652,11 @@ def main(): | |||
| 608 | project="textual_inversion", | 652 | project="textual_inversion", |
| 609 | train_dataloader=datamodule.train_dataloader, | 653 | train_dataloader=datamodule.train_dataloader, |
| 610 | val_dataloader=datamodule.val_dataloader, | 654 | val_dataloader=datamodule.val_dataloader, |
| 655 | seed=args.seed, | ||
| 611 | optimizer=optimizer, | 656 | optimizer=optimizer, |
| 612 | lr_scheduler=lr_scheduler, | 657 | lr_scheduler=lr_scheduler, |
| 613 | num_train_epochs=args.ti_num_train_epochs, | 658 | num_train_epochs=args.ti_num_train_epochs, |
| 614 | sample_frequency=2, | 659 | sample_frequency=args.ti_num_train_epochs // 5, |
| 615 | checkpoint_frequency=9999999, | 660 | checkpoint_frequency=9999999, |
| 616 | # -- | 661 | # -- |
| 617 | tokenizer=tokenizer, | 662 | tokenizer=tokenizer, |
| @@ -637,7 +682,7 @@ def main(): | |||
| 637 | cur_dir = output_dir.joinpath("2-db") | 682 | cur_dir = output_dir.joinpath("2-db") |
| 638 | cur_dir.mkdir(parents=True, exist_ok=True) | 683 | cur_dir.mkdir(parents=True, exist_ok=True) |
| 639 | 684 | ||
| 640 | args.seed = (args.seed + 28635) >> 32 | 685 | args.seed = seed_generator.seed() |
| 641 | 686 | ||
| 642 | datamodule = VlpnDataModule( | 687 | datamodule = VlpnDataModule( |
| 643 | data_file=args.train_data_file, | 688 | data_file=args.train_data_file, |
| @@ -654,7 +699,8 @@ def main(): | |||
| 654 | shuffle=not args.no_tag_shuffle, | 699 | shuffle=not args.no_tag_shuffle, |
| 655 | template_key=args.train_data_template, | 700 | template_key=args.train_data_template, |
| 656 | valid_set_size=args.valid_set_size, | 701 | valid_set_size=args.valid_set_size, |
| 657 | valid_set_repeat=args.valid_set_repeat, | 702 | train_set_pad=args.train_set_pad, |
| 703 | valid_set_pad=args.valid_set_pad, | ||
| 658 | seed=args.seed, | 704 | seed=args.seed, |
| 659 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 705 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
| 660 | dtype=weight_dtype | 706 | dtype=weight_dtype |
| @@ -697,6 +743,7 @@ def main(): | |||
| 697 | project="dreambooth", | 743 | project="dreambooth", |
| 698 | train_dataloader=datamodule.train_dataloader, | 744 | train_dataloader=datamodule.train_dataloader, |
| 699 | val_dataloader=datamodule.val_dataloader, | 745 | val_dataloader=datamodule.val_dataloader, |
| 746 | seed=args.seed, | ||
| 700 | optimizer=optimizer, | 747 | optimizer=optimizer, |
| 701 | lr_scheduler=lr_scheduler, | 748 | lr_scheduler=lr_scheduler, |
| 702 | num_train_epochs=args.num_train_epochs, | 749 | num_train_epochs=args.num_train_epochs, |
diff --git a/train_ti.py b/train_ti.py index a894ee7..7aecdef 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -360,10 +360,16 @@ def parse_args(): | |||
| 360 | help="Number of images in the validation dataset." | 360 | help="Number of images in the validation dataset." |
| 361 | ) | 361 | ) |
| 362 | parser.add_argument( | 362 | parser.add_argument( |
| 363 | "--valid_set_repeat", | 363 | "--train_set_pad", |
| 364 | type=int, | 364 | type=int, |
| 365 | default=1, | 365 | default=None, |
| 366 | help="Times the images in the validation dataset are repeated." | 366 | help="The number to fill train dataset items up to." |
| 367 | ) | ||
| 368 | parser.add_argument( | ||
| 369 | "--valid_set_pad", | ||
| 370 | type=int, | ||
| 371 | default=None, | ||
| 372 | help="The number to fill validation dataset items up to." | ||
| 367 | ) | 373 | ) |
| 368 | parser.add_argument( | 374 | parser.add_argument( |
| 369 | "--train_batch_size", | 375 | "--train_batch_size", |
| @@ -575,7 +581,8 @@ def main(): | |||
| 575 | shuffle=not args.no_tag_shuffle, | 581 | shuffle=not args.no_tag_shuffle, |
| 576 | template_key=args.train_data_template, | 582 | template_key=args.train_data_template, |
| 577 | valid_set_size=args.valid_set_size, | 583 | valid_set_size=args.valid_set_size, |
| 578 | valid_set_repeat=args.valid_set_repeat, | 584 | train_set_pad=args.train_set_pad, |
| 585 | valid_set_pad=args.valid_set_pad, | ||
| 579 | seed=args.seed, | 586 | seed=args.seed, |
| 580 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | 587 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), |
| 581 | dtype=weight_dtype | 588 | dtype=weight_dtype |
| @@ -590,7 +597,7 @@ def main(): | |||
| 590 | unet, | 597 | unet, |
| 591 | tokenizer, | 598 | tokenizer, |
| 592 | sample_scheduler, | 599 | sample_scheduler, |
| 593 | datamodule.data_train, | 600 | datamodule.train_dataset, |
| 594 | args.sample_batch_size, | 601 | args.sample_batch_size, |
| 595 | args.sample_image_size, | 602 | args.sample_image_size, |
| 596 | args.sample_steps | 603 | args.sample_steps |
diff --git a/training/functional.py b/training/functional.py index c6b4dc3..b6b5d87 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol | |||
| 17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
| 18 | from PIL import Image | 18 | from PIL import Image |
| 19 | 19 | ||
| 20 | from data.csv import VlpnDataset | ||
| 20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 21 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 21 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
| 22 | from models.clip.util import get_extended_embeddings | 23 | from models.clip.util import get_extended_embeddings |
| @@ -175,12 +176,12 @@ def generate_class_images( | |||
| 175 | unet: UNet2DConditionModel, | 176 | unet: UNet2DConditionModel, |
| 176 | tokenizer: MultiCLIPTokenizer, | 177 | tokenizer: MultiCLIPTokenizer, |
| 177 | sample_scheduler: DPMSolverMultistepScheduler, | 178 | sample_scheduler: DPMSolverMultistepScheduler, |
| 178 | data_train, | 179 | train_dataset: VlpnDataset, |
| 179 | sample_batch_size: int, | 180 | sample_batch_size: int, |
| 180 | sample_image_size: int, | 181 | sample_image_size: int, |
| 181 | sample_steps: int | 182 | sample_steps: int |
| 182 | ): | 183 | ): |
| 183 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | 184 | missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] |
| 184 | 185 | ||
| 185 | if len(missing_data) == 0: | 186 | if len(missing_data) == 0: |
| 186 | return | 187 | return |
diff --git a/training/optimization.py b/training/optimization.py index 5db7794..6dee4bc 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -49,8 +49,8 @@ def get_one_cycle_schedule( | |||
| 49 | annealing: Literal["cos", "half_cos", "linear"] = "cos", | 49 | annealing: Literal["cos", "half_cos", "linear"] = "cos", |
| 50 | warmup_exp: int = 1, | 50 | warmup_exp: int = 1, |
| 51 | annealing_exp: int = 1, | 51 | annealing_exp: int = 1, |
| 52 | min_lr: int = 0.04, | 52 | min_lr: float = 0.04, |
| 53 | mid_point: int = 0.3, | 53 | mid_point: float = 0.3, |
| 54 | last_epoch: int = -1 | 54 | last_epoch: int = -1 |
| 55 | ): | 55 | ): |
| 56 | if warmup == "linear": | 56 | if warmup == "linear": |
| @@ -91,10 +91,10 @@ def get_scheduler( | |||
| 91 | id: str, | 91 | id: str, |
| 92 | optimizer: torch.optim.Optimizer, | 92 | optimizer: torch.optim.Optimizer, |
| 93 | num_training_steps_per_epoch: int, | 93 | num_training_steps_per_epoch: int, |
| 94 | gradient_accumulation_steps: int, | 94 | gradient_accumulation_steps: int = 1, |
| 95 | min_lr: float = 0.04, | 95 | min_lr: float = 0.04, |
| 96 | warmup_func: str = "cos", | 96 | warmup_func: Literal["cos", "linear"] = "cos", |
| 97 | annealing_func: str = "cos", | 97 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", |
| 98 | warmup_exp: int = 1, | 98 | warmup_exp: int = 1, |
| 99 | annealing_exp: int = 1, | 99 | annealing_exp: int = 1, |
| 100 | cycles: int = 1, | 100 | cycles: int = 1, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 568f9eb..9d39e15 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -36,7 +36,7 @@ def textual_inversion_strategy( | |||
| 36 | use_emb_decay: bool = False, | 36 | use_emb_decay: bool = False, |
| 37 | emb_decay_target: float = 0.4, | 37 | emb_decay_target: float = 0.4, |
| 38 | emb_decay_factor: float = 1, | 38 | emb_decay_factor: float = 1, |
| 39 | emb_decay_start: float = 1e-4, | 39 | emb_decay_start: float = 0, |
| 40 | use_ema: bool = False, | 40 | use_ema: bool = False, |
| 41 | ema_inv_gamma: float = 1.0, | 41 | ema_inv_gamma: float = 1.0, |
| 42 | ema_power: int = 1, | 42 | ema_power: int = 1, |
