summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  6
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index adba8d4..e696577 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -409,7 +409,7 @@ def parse_args():
409 ) 409 )
410 parser.add_argument( 410 parser.add_argument(
411 "--emb_decay_start", 411 "--emb_decay_start",
412 default=1e-4, 412 default=0,
413 type=float, 413 type=float,
414 help="Embedding decay start offset." 414 help="Embedding decay start offset."
415 ) 415 )
@@ -514,8 +514,6 @@ def main():
514 514
515 set_seed(args.seed) 515 set_seed(args.seed)
516 516
517 seed_generator = torch.Generator().manual_seed(args.seed)
518
519 save_args(output_dir, args) 517 save_args(output_dir, args)
520 518
521 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 519 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -595,8 +593,6 @@ def main():
595 print( 593 print(
596 f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") 594 f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
597 595
598 args.seed = seed_generator.seed()
599
600 datamodule = VlpnDataModule( 596 datamodule = VlpnDataModule(
601 data_file=args.train_data_file, 597 data_file=args.train_data_file,
602 batch_size=args.train_batch_size, 598 batch_size=args.train_batch_size,