 data/csv.py         |  3 +++
 train_dreambooth.py | 40 ++++++++++++++++++++--------------------
 2 files changed, 23 insertions(+), 20 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 968af8d..dec66d7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -267,6 +267,9 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
 
+        if (len(items) < self.batch_size):
+            items = (items * self.batch_size)[:self.batch_size]
+
         num_images = len(items)
 
         valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
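Note on the data/csv.py change: when the dataset holds fewer items than one batch, the new guard pads the item list by repetition so that at least one full batch exists. A minimal standalone sketch of the same expression (names here are illustrative, not from the repo):

    batch_size = 4
    items = ["a.png", "b.png"]  # fewer items than a single batch

    if len(items) < batch_size:
        # list repetition then truncation:
        # ["a.png", "b.png"] * 4 -> 8 entries, sliced back down to 4
        items = (items * batch_size)[:batch_size]

    assert items == ["a.png", "b.png", "a.png", "b.png"]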
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4e41f77..a9fbbbd 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -87,12 +87,6 @@ def parse_args():
         help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        default=True,
-        help="Whether to train the whole text encoder."
-    )
-    parser.add_argument(
         "--train_text_encoder_epochs",
         default=999999,
         help="Number of epochs the text encoder will be trained."
@@ -194,6 +188,11 @@ def parse_args():
         default=100
     )
     parser.add_argument(
+        "--ti_num_train_epochs",
+        type=int,
+        default=10
+    )
+    parser.add_argument(
         "--max_train_steps",
         type=int,
         default=None,
@@ -222,6 +221,12 @@ def parse_args():
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
+        "--ti_learning_rate",
+        type=float,
+        default=1e-2,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
         "--scale_lr",
         action="store_true",
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
@@ -541,15 +546,14 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
     )
 
     # Initial TI
 
     print("Phase 1: Textual Inversion")
 
-    ti_lr = 1e-1
-    ti_train_epochs = 10
-
     cur_dir = output_dir.joinpath("1-ti")
     cur_dir.mkdir(parents=True, exist_ok=True)
 
@@ -576,7 +580,7 @@
         size=args.resolution,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
+        valid_set_size=1,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
         filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
@@ -586,7 +590,7 @@
 
     optimizer = optimizer_class(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=ti_lr,
+        lr=args.ti_learning_rate,
         weight_decay=0.0,
     )
 
@@ -595,8 +599,8 @@
         optimizer=optimizer,
         num_training_steps_per_epoch=len(datamodule.train_dataloader),
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        train_epochs=ti_train_epochs,
-        warmup_epochs=ti_train_epochs // 4,
+        train_epochs=args.ti_num_train_epochs,
+        warmup_epochs=args.ti_num_train_epochs // 4,
     )
 
     trainer(
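Since warmup_epochs is derived by integer division, the default of 10 TI epochs yields 10 // 4 == 2 warmup epochs, and any --ti_num_train_epochs below 4 yields no warmup at all. A quick check in plain Python:

    for n in (1, 4, 10, 20):
        print(n, "epochs ->", n // 4, "warmup")  # 0, 1, 2, 5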
@@ -606,18 +610,16 @@
         val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=ti_train_epochs,
-        sample_frequency=1,
+        num_train_epochs=args.ti_num_train_epochs,
+        sample_frequency=2,
         checkpoint_frequency=9999999,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         output_dir=cur_subdir,
         placeholder_tokens=[placeholder_token],
         placeholder_token_ids=placeholder_token_ids,
-        learning_rate=ti_lr,
+        learning_rate=args.ti_learning_rate,
         gradient_checkpointing=args.gradient_checkpointing,
         use_emb_decay=True,
         sample_batch_size=args.sample_batch_size,
@@ -699,8 +701,6 @@
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
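Both per-phase trainer() calls drop with_prior_preservation and prior_loss_weight because the first main() hunk above now passes them once where the shared trainer is constructed, so the TI and Dreambooth phases can no longer diverge on these settings. A rough sketch of that pattern, assuming the trainer factory binds keyword arguments roughly like functools.partial (the repo's actual factory may differ):

    from functools import partial

    def train_loop(*, with_prior_preservation, prior_loss_weight, **kwargs):
        ...  # shared training loop body

    # bound once at construction; every later trainer(...) call inherits them
    trainer = partial(
        train_loop,
        with_prior_preservation=args.num_class_images != 0,
        prior_loss_weight=args.prior_loss_weight,
    )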
