From 96e887da4be2c13f5f58da3359a9ab891c44d050 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 16 Jan 2023 18:45:03 +0100
Subject: If valid set size is 0, re-use one image from train set

---
 data/csv.py | 2 +-
 train_ti.py | 6 +-----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 6857b6f..85b98f8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -282,7 +282,7 @@ class VlpnDataModule():
         collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
 
         if valid_set_size == 0:
-            data_train, data_val = items, []
+            data_train, data_val = items, items[:1]
         else:
             data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
 
diff --git a/train_ti.py b/train_ti.py
index adba8d4..e696577 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -409,7 +409,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_start",
-        default=1e-4,
+        default=0,
         type=float,
         help="Embedding decay start offset."
     )
@@ -514,8 +514,6 @@ def main():
 
     set_seed(args.seed)
 
-    seed_generator = torch.Generator().manual_seed(args.seed)
-
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -595,8 +593,6 @@ def main():
     print(
         f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
 
-    args.seed = seed_generator.seed()
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
--
cgit v1.2.3-70-g09d2
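
The data/csv.py hunk is the substantive change: with valid_set_size == 0, the datamodule previously produced an empty validation set; it now trains on all items and re-uses the first one as a single-image validation set, so the validation loop always has data to run on. The train_ti.py hunks are companion cleanups: the --emb_decay_start default drops from 1e-4 to 0, and the seed_generator that rewrote args.seed before building the datamodule is removed, so the seed fixed by set_seed(args.seed) stays in effect.

The sketch below isolates the patched split logic as a standalone function. It is a minimal reconstruction for illustration only, not code from the repository: the helper name split_items is hypothetical, and the surrounding VlpnDataModule plumbing is omitted.

    import torch
    from torch.utils.data import random_split

    def split_items(items, valid_set_size, generator=None):
        # Hypothetical helper mirroring the patched branch in
        # VlpnDataModule (data/csv.py); not a function in the repo.
        if valid_set_size == 0:
            # New behavior: re-use one training image as the validation
            # set instead of returning an empty list.
            return items, items[:1]
        train_set_size = len(items) - valid_set_size
        return random_split(items, [train_set_size, valid_set_size], generator=generator)

    # valid_set_size == 0: everything trains, one item doubles as validation.
    data_train, data_val = split_items(list(range(10)), valid_set_size=0)
    assert len(data_train) == 10 and len(data_val) == 1

    # Non-zero size: a seeded random split, as before the patch.
    train_subset, val_subset = split_items(
        list(range(10)), valid_set_size=2, generator=torch.Generator().manual_seed(0)
    )
    assert len(train_subset) == 8 and len(val_subset) == 2

Note that the re-used validation image is also trained on, so the resulting validation loss is a sanity check rather than a held-out estimate; the trade-off is that downstream code never has to special-case an empty validation loader.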