From 26ece1a796c7ef87ed96f5b38fab80d0ae958b9a Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 13 Dec 2022 10:45:53 +0100
Subject: Fixed sample/checkpoint frequency

---
 textual_inversion.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/textual_inversion.py b/textual_inversion.py
index 11babd8..19b8993 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -302,15 +302,11 @@ def parse_args():
         "--config",
         type=str,
         default=None,
-        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
+        help="Path to a JSON configuration file containing arguments for invoking this script."
     )
 
     args = parser.parse_args()
-    if args.resume_from is not None:
-        with open(f"{args.resume_from}/resume.json", 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
-    elif args.config is not None:
+    if args.config is not None:
         with open(args.config, 'rt') as f:
             args = parser.parse_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))
@@ -550,8 +546,6 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    print(f"Adding text embeddings: {args.placeholder_token}")
-
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = torch.stack([
         torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -563,13 +557,17 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    print(f"Token ID mappings:")
+    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+        print(f"- {token_id} {token}")
+
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
 
-    if args.resume_from:
+    if args.resume_from is not None:
         resumepath = Path(args.resume_from).joinpath("checkpoints")
 
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
@@ -983,13 +981,13 @@ def main():
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                 min_val_loss = val_loss
 
-            if epoch % args.checkpoint_frequency == 0:
+            if (epoch + 1) % args.checkpoint_frequency == 0:
                 checkpointer.checkpoint(global_step + global_step_offset, "training")
                 save_args(basepath, args, {
                     "global_step": global_step + global_step_offset
                 })
 
-            if epoch % args.sample_frequency == 0:
+            if (epoch + 1) % args.sample_frequency == 0:
                 checkpointer.save_samples(
                     global_step + global_step_offset,
                     args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
--
cgit v1.2.3-54-g00ecf
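
For context, a minimal standalone Python sketch (not part of the patch) of the off-by-one behaviour the last hunk addresses; trigger_epochs, num_epochs, frequency, and fixed are hypothetical names used only for illustration.

# A minimal sketch, assuming a 0-based epoch counter as in the training loop above.
# `epoch % frequency == 0` fires on epoch 0 and then one epoch "early" thereafter;
# `(epoch + 1) % frequency == 0` fires only after a full interval of `frequency`
# epochs has completed, which is the check the patch switches to.
def trigger_epochs(num_epochs: int, frequency: int, fixed: bool) -> list[int]:
    """Return the 0-based epochs on which a checkpoint/sample would be taken."""
    return [
        epoch for epoch in range(num_epochs)
        if ((epoch + 1) if fixed else epoch) % frequency == 0
    ]

if __name__ == "__main__":
    # Old behaviour: triggers at epochs 0, 5, 10 (including epoch 0).
    print(trigger_epochs(num_epochs=15, frequency=5, fixed=False))  # [0, 5, 10]
    # New behaviour: triggers at epochs 4, 9, 14, i.e. after every 5 completed epochs.
    print(trigger_epochs(num_epochs=15, frequency=5, fixed=True))   # [4, 9, 14]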