From 26ece1a796c7ef87ed96f5b38fab80d0ae958b9a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 13 Dec 2022 10:45:53 +0100 Subject: Fixed sample/checkpoint frequency --- dreambooth.py | 4 +--- textual_inversion.py | 20 +++++++++----------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/dreambooth.py b/dreambooth.py index 9a6f70a..31416e9 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -625,8 +625,6 @@ def main(): vae.requires_grad_(False) if len(args.placeholder_token) != 0: - print(f"Adding text embeddings: {args.placeholder_token}") - # Convert the initializer_token, placeholder_token to ids initializer_token_ids = torch.stack([ torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) @@ -1114,7 +1112,7 @@ def main(): max_acc_val = avg_acc_val.avg.item() if accelerator.is_main_process: - if epoch % args.sample_frequency == 0: + if (epoch + 1) % args.sample_frequency == 0: checkpointer.save_samples(global_step, args.sample_steps) # Create the pipeline using using the trained modules and save it. diff --git a/textual_inversion.py b/textual_inversion.py index 11babd8..19b8993 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -302,15 +302,11 @@ def parse_args(): "--config", type=str, default=None, - help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this." + help="Path to a JSON configuration file containing arguments for invoking this script." ) args = parser.parse_args() - if args.resume_from is not None: - with open(f"{args.resume_from}/resume.json", 'rt') as f: - args = parser.parse_args( - namespace=argparse.Namespace(**json.load(f)["args"])) - elif args.config is not None: + if args.config is not None: with open(args.config, 'rt') as f: args = parser.parse_args( namespace=argparse.Namespace(**json.load(f)["args"])) @@ -550,8 +546,6 @@ def main(): unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() - print(f"Adding text embeddings: {args.placeholder_token}") - # Convert the initializer_token, placeholder_token to ids initializer_token_ids = torch.stack([ torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) @@ -563,13 +557,17 @@ def main(): placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + print(f"Token ID mappings:") + for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): + print(f"- {token_id} {token}") + # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) # Initialise the newly added placeholder token with the embeddings of the initializer token token_embeds = text_encoder.get_input_embeddings().weight.data - if args.resume_from: + if args.resume_from is not None: resumepath = Path(args.resume_from).joinpath("checkpoints") for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): @@ -983,13 +981,13 @@ def main(): checkpointer.checkpoint(global_step + global_step_offset, "milestone") min_val_loss = val_loss - if epoch % args.checkpoint_frequency == 0: + if (epoch + 1) % args.checkpoint_frequency == 0: checkpointer.checkpoint(global_step + global_step_offset, "training") save_args(basepath, args, { "global_step": global_step + global_step_offset }) - if epoch % args.sample_frequency == 0: + if (epoch + 1) % args.sample_frequency == 0: checkpointer.save_samples( global_step + global_step_offset, args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) -- cgit v1.2.3-70-g09d2