author     Volpeon <git@volpeon.ink>    2022-12-13 10:45:53 +0100
committer  Volpeon <git@volpeon.ink>    2022-12-13 10:45:53 +0100
commit     26ece1a796c7ef87ed96f5b38fab80d0ae958b9a
tree       d733df7000d5259c1a3ad19abd6dec982d65974d
parent     Add support for resume in Textual Inversion
Fixed sample/checkpoint frequency
-rw-r--r--  dreambooth.py        |  4 +---
-rw-r--r--  textual_inversion.py | 20 +++++++++-----------
2 files changed, 10 insertions(+), 14 deletions(-)
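
The fix replaces the trigger condition epoch % frequency == 0 with (epoch + 1) % frequency == 0. With a zero-indexed epoch counter, the old test fires right after the first epoch and then every frequency epochs, while the new test fires only once a full block of frequency epochs has completed. A minimal sketch (hypothetical frequency value, not taken from this repo) comparing the two conditions:

    # Sketch only: compare the old vs. new sample/checkpoint trigger for a
    # hypothetical sample_frequency of 5 over 10 zero-indexed epochs.
    sample_frequency = 5

    old = [epoch for epoch in range(10) if epoch % sample_frequency == 0]
    new = [epoch for epoch in range(10) if (epoch + 1) % sample_frequency == 0]

    print(old)  # [0, 5] -> triggers on epoch 0, i.e. after only one epoch of training
    print(new)  # [4, 9] -> triggers after every 5 completed epochs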
diff --git a/dreambooth.py b/dreambooth.py
index 9a6f70a..31416e9 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -625,8 +625,6 @@ def main():
     vae.requires_grad_(False)
 
     if len(args.placeholder_token) != 0:
-        print(f"Adding text embeddings: {args.placeholder_token}")
-
         # Convert the initializer_token, placeholder_token to ids
         initializer_token_ids = torch.stack([
             torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -1114,7 +1112,7 @@ def main():
            max_acc_val = avg_acc_val.avg.item()
 
        if accelerator.is_main_process:
-            if epoch % args.sample_frequency == 0:
+            if (epoch + 1) % args.sample_frequency == 0:
                checkpointer.save_samples(global_step, args.sample_steps)
 
    # Create the pipeline using using the trained modules and save it.
diff --git a/textual_inversion.py b/textual_inversion.py
index 11babd8..19b8993 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -302,15 +302,11 @@ def parse_args():
302 "--config", 302 "--config",
303 type=str, 303 type=str,
304 default=None, 304 default=None,
305 help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this." 305 help="Path to a JSON configuration file containing arguments for invoking this script."
306 ) 306 )
307 307
308 args = parser.parse_args() 308 args = parser.parse_args()
309 if args.resume_from is not None: 309 if args.config is not None:
310 with open(f"{args.resume_from}/resume.json", 'rt') as f:
311 args = parser.parse_args(
312 namespace=argparse.Namespace(**json.load(f)["args"]))
313 elif args.config is not None:
314 with open(args.config, 'rt') as f: 310 with open(args.config, 'rt') as f:
315 args = parser.parse_args( 311 args = parser.parse_args(
316 namespace=argparse.Namespace(**json.load(f)["args"])) 312 namespace=argparse.Namespace(**json.load(f)["args"]))
@@ -550,8 +546,6 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    print(f"Adding text embeddings: {args.placeholder_token}")
-
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
         torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -563,13 +557,17 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    print(f"Token ID mappings:")
+    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+        print(f"- {token_id} {token}")
+
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
 
-    if args.resume_from:
+    if args.resume_from is not None:
         resumepath = Path(args.resume_from).joinpath("checkpoints")
 
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
@@ -983,13 +981,13 @@ def main():
                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                    min_val_loss = val_loss
 
-                if epoch % args.checkpoint_frequency == 0:
+                if (epoch + 1) % args.checkpoint_frequency == 0:
                    checkpointer.checkpoint(global_step + global_step_offset, "training")
                    save_args(basepath, args, {
                        "global_step": global_step + global_step_offset
                    })
 
-                if epoch % args.sample_frequency == 0:
+                if (epoch + 1) % args.sample_frequency == 0:
                    checkpointer.save_samples(
                        global_step + global_step_offset,
                        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
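
The surviving --config path in parse_args() reuses argparse's namespace argument so that values loaded from the JSON file act as pre-filled defaults which explicit command-line flags still override; the removed resume_from branch did the same thing with resume.json. A minimal, self-contained sketch of that pattern (hypothetical script and argument list, not this repo's full parser):

    # Sketch only: JSON config values pre-populate the argparse namespace,
    # then a second parse_args() lets explicit command-line flags win.
    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--sample_frequency", type=int, default=1)

    args = parser.parse_args()
    if args.config is not None:
        with open(args.config, "rt") as f:
            # The config file stores its values under an "args" key, matching
            # the json.load(f)["args"] access seen in the diff above.
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    print(args)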