From 26ece1a796c7ef87ed96f5b38fab80d0ae958b9a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 13 Dec 2022 10:45:53 +0100 Subject: Fixed sample/checkpoint frequency --- dreambooth.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 9a6f70a..31416e9 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -625,8 +625,6 @@ def main(): vae.requires_grad_(False) if len(args.placeholder_token) != 0: - print(f"Adding text embeddings: {args.placeholder_token}") - # Convert the initializer_token, placeholder_token to ids initializer_token_ids = torch.stack([ torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) @@ -1114,7 +1112,7 @@ def main(): max_acc_val = avg_acc_val.avg.item() if accelerator.is_main_process: - if epoch % args.sample_frequency == 0: + if (epoch + 1) % args.sample_frequency == 0: checkpointer.save_samples(global_step, args.sample_steps) # Create the pipeline using using the trained modules and save it. -- cgit v1.2.3-54-g00ecf