diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 4 |
1 files changed, 1 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py index 9a6f70a..31416e9 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -625,8 +625,6 @@ def main(): | |||
| 625 | vae.requires_grad_(False) | 625 | vae.requires_grad_(False) |
| 626 | 626 | ||
| 627 | if len(args.placeholder_token) != 0: | 627 | if len(args.placeholder_token) != 0: |
| 628 | print(f"Adding text embeddings: {args.placeholder_token}") | ||
| 629 | |||
| 630 | # Convert the initializer_token, placeholder_token to ids | 628 | # Convert the initializer_token, placeholder_token to ids |
| 631 | initializer_token_ids = torch.stack([ | 629 | initializer_token_ids = torch.stack([ |
| 632 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) | 630 | torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) |
| @@ -1114,7 +1112,7 @@ def main(): | |||
| 1114 | max_acc_val = avg_acc_val.avg.item() | 1112 | max_acc_val = avg_acc_val.avg.item() |
| 1115 | 1113 | ||
| 1116 | if accelerator.is_main_process: | 1114 | if accelerator.is_main_process: |
| 1117 | if epoch % args.sample_frequency == 0: | 1115 | if (epoch + 1) % args.sample_frequency == 0: |
| 1118 | checkpointer.save_samples(global_step, args.sample_steps) | 1116 | checkpointer.save_samples(global_step, args.sample_steps) |
| 1119 | 1117 | ||
| 1120 | # Create the pipeline using using the trained modules and save it. | 1118 | # Create the pipeline using using the trained modules and save it. |
