From 9bd1f6b84e58cee0fc2d869a8db2c32f7efe488f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 16 Jan 2023 10:51:02 +0100
Subject: Pad dataset if len(items) < batch_size

---
 data/csv.py         |  3 +++
 train_dreambooth.py | 40 ++++++++++++++++++++--------------------
 2 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 968af8d..dec66d7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -267,6 +267,9 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
 
+        if (len(items) < self.batch_size):
+            items = (items * self.batch_size)[:self.batch_size]
+
         num_images = len(items)
 
         valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4e41f77..a9fbbbd 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -86,12 +86,6 @@ def parse_args():
         nargs='*',
         help="Exclude all items with a listed collection.",
     )
-    parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        default=True,
-        help="Whether to train the whole text encoder."
-    )
     parser.add_argument(
         "--train_text_encoder_epochs",
         default=999999,
@@ -193,6 +187,11 @@
         type=int,
         default=100
     )
+    parser.add_argument(
+        "--ti_num_train_epochs",
+        type=int,
+        default=10
+    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -221,6 +220,12 @@
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--ti_learning_rate",
+        type=float,
+        default=1e-2,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -541,15 +546,14 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
     )
 
     # Initial TI
 
     print("Phase 1: Textual Inversion")
 
-    ti_lr = 1e-1
-    ti_train_epochs = 10
-
     cur_dir = output_dir.joinpath("1-ti")
     cur_dir.mkdir(parents=True, exist_ok=True)
 
@@ -576,7 +580,7 @@
         size=args.resolution,
         shuffle=not args.no_tag_shuffle,
         template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
+        valid_set_size=1,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
         filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
@@ -586,7 +590,7 @@
 
     optimizer = optimizer_class(
         text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=ti_lr,
+        lr=args.ti_learning_rate,
         weight_decay=0.0,
     )
 
@@ -595,8 +599,8 @@
         optimizer=optimizer,
         num_training_steps_per_epoch=len(datamodule.train_dataloader),
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        train_epochs=ti_train_epochs,
-        warmup_epochs=ti_train_epochs // 4,
+        train_epochs=args.ti_num_train_epochs,
+        warmup_epochs=args.ti_num_train_epochs // 4,
     )
 
     trainer(
@@ -606,18 +610,16 @@
         val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=ti_train_epochs,
-        sample_frequency=1,
+        num_train_epochs=args.ti_num_train_epochs,
+        sample_frequency=2,
         checkpoint_frequency=9999999,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         output_dir=cur_subdir,
         placeholder_tokens=[placeholder_token],
         placeholder_token_ids=placeholder_token_ids,
-        learning_rate=ti_lr,
+        learning_rate=args.ti_learning_rate,
         gradient_checkpointing=args.gradient_checkpointing,
         use_emb_decay=True,
         sample_batch_size=args.sample_batch_size,
@@ -699,8 +701,6 @@
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-- 
cgit v1.2.3-70-g09d2
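
A quick illustration of the padding idiom this commit adds to VlpnDataModule
(a minimal standalone sketch with made-up values, not part of the patch):
multiplying the list by batch_size always over-allocates when the list is
non-empty, and the slice then trims it to exactly one batch.

    # Demo of (items * batch_size)[:batch_size] with illustrative values
    items = ["a", "b", "c"]
    batch_size = 8

    if len(items) < batch_size:
        items = (items * batch_size)[:batch_size]

    print(items)  # ['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b']

Note that an empty list stays empty, so the guarantee only holds once at
least one item survives filter_items.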
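The commit also hoists with_prior_preservation and prior_loss_weight out of
the two per-phase trainer(...) calls into the shared trainer setup. A rough
sketch of that refactor, assuming trainer is built with functools.partial
(the filter=partial(keyword_filter, ...) line shows the helper is already in
scope); the train stub below is hypothetical:

    from functools import partial

    def train(*, seed, with_prior_preservation, prior_loss_weight, learning_rate):
        # Stand-in for the real training loop; just echoes its configuration.
        print(seed, with_prior_preservation, prior_loss_weight, learning_rate)

    # Options common to both phases are bound once...
    trainer = partial(train, seed=42, with_prior_preservation=True, prior_loss_weight=1.0)

    # ...and each phase passes only what actually differs.
    trainer(learning_rate=1e-2)  # TI phase
    trainer(learning_rate=2e-6)  # Dreambooth phase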