summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py2
-rw-r--r--train_ti.py6
2 files changed, 2 insertions, 6 deletions
diff --git a/data/csv.py b/data/csv.py
index 6857b6f..85b98f8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -282,7 +282,7 @@ class VlpnDataModule():
282 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) 282 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
283 283
284 if valid_set_size == 0: 284 if valid_set_size == 0:
285 data_train, data_val = items, [] 285 data_train, data_val = items, items[:1]
286 else: 286 else:
287 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) 287 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
288 288
diff --git a/train_ti.py b/train_ti.py
index adba8d4..e696577 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -409,7 +409,7 @@ def parse_args():
409 ) 409 )
410 parser.add_argument( 410 parser.add_argument(
411 "--emb_decay_start", 411 "--emb_decay_start",
412 default=1e-4, 412 default=0,
413 type=float, 413 type=float,
414 help="Embedding decay start offset." 414 help="Embedding decay start offset."
415 ) 415 )
@@ -514,8 +514,6 @@ def main():
514 514
515 set_seed(args.seed) 515 set_seed(args.seed)
516 516
517 seed_generator = torch.Generator().manual_seed(args.seed)
518
519 save_args(output_dir, args) 517 save_args(output_dir, args)
520 518
521 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 519 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -595,8 +593,6 @@ def main():
595 print( 593 print(
596 f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") 594 f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
597 595
598 args.seed = seed_generator.seed()
599
600 datamodule = VlpnDataModule( 596 datamodule = VlpnDataModule(
601 data_file=args.train_data_file, 597 data_file=args.train_data_file,
602 batch_size=args.train_batch_size, 598 batch_size=args.train_batch_size,