diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 4 |
1 files changed, 1 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py index 9a6f70a..31416e9 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -625,8 +625,6 @@ def main(): | |||
625 | vae.requires_grad_(False) | 625 | vae.requires_grad_(False) |
626 | 626 | ||
627 | if len(args.placeholder_token) != 0: | 627 | if len(args.placeholder_token) != 0: |
628 | print(f"Adding text embeddings: {args.placeholder_token}") | ||
629 | |||
630 | # Convert the initializer_token, placeholder_token to ids | 628 | # Convert the initializer_token, placeholder_token to ids |
631 | initializer_token_ids = torch.stack([ | 629 | initializer_token_ids = torch.stack([ |
632 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) | 630 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) |
@@ -1114,7 +1112,7 @@ def main(): | |||
1114 | max_acc_val = avg_acc_val.avg.item() | 1112 | max_acc_val = avg_acc_val.avg.item() |
1115 | 1113 | ||
1116 | if accelerator.is_main_process: | 1114 | if accelerator.is_main_process: |
1117 | if epoch % args.sample_frequency == 0: | 1115 | if (epoch + 1) % args.sample_frequency == 0: |
1118 | checkpointer.save_samples(global_step, args.sample_steps) | 1116 | checkpointer.save_samples(global_step, args.sample_steps) |
1119 | 1117 | ||
1120 | # Create the pipeline using using the trained modules and save it. | 1118 | # Create the pipeline using using the trained modules and save it. |