summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-02 17:12:06 +0100
committerVolpeon <git@volpeon.ink>2023-01-02 17:12:06 +0100
commitb1d7b2962e6454f8e72bd64efe08dc80a1f2d3aa (patch)
tree8e93d7c884b2ba0d86748ad200173fcdb39df5ef
parentSave args before training, too (diff)
downloadtextual-inversion-diff-b1d7b2962e6454f8e72bd64efe08dc80a1f2d3aa.tar.gz
textual-inversion-diff-b1d7b2962e6454f8e72bd64efe08dc80a1f2d3aa.tar.bz2
textual-inversion-diff-b1d7b2962e6454f8e72bd64efe08dc80a1f2d3aa.zip
Fix
-rw-r--r--train_dreambooth.py4
-rw-r--r--train_ti.py6
2 files changed, 2 insertions, 8 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index cd0bf67..05f6cb5 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -955,10 +955,6 @@ def main():
955 ) 955 )
956 global_progress_bar.set_description("Total progress") 956 global_progress_bar.set_description("Total progress")
957 957
958 save_args(basepath, args, {
959 "global_step": global_step + global_step_offset
960 })
961
962 try: 958 try:
963 for epoch in range(num_epochs): 959 for epoch in range(num_epochs):
964 if accelerator.is_main_process: 960 if accelerator.is_main_process:
diff --git a/train_ti.py b/train_ti.py
index 6c74854..97dde1e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -519,6 +519,8 @@ def main():
519 args.seed = args.seed or (torch.random.seed() >> 32) 519 args.seed = args.seed or (torch.random.seed() >> 32)
520 set_seed(args.seed) 520 set_seed(args.seed)
521 521
522 save_args(basepath, args)
523
522 # Load the tokenizer and add the placeholder token as a additional special token 524 # Load the tokenizer and add the placeholder token as a additional special token
523 if args.tokenizer_name: 525 if args.tokenizer_name:
524 tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name) 526 tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -903,10 +905,6 @@ def main():
903 ) 905 )
904 global_progress_bar.set_description("Total progress") 906 global_progress_bar.set_description("Total progress")
905 907
906 save_args(basepath, args, {
907 "global_step": global_step + global_step_offset
908 })
909
910 try: 908 try:
911 for epoch in range(num_epochs): 909 for epoch in range(num_epochs):
912 if accelerator.is_main_process: 910 if accelerator.is_main_process: