diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py index ec9531e..0044c1e 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -1,13 +1,11 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import itertools | 2 | import itertools |
3 | import math | 3 | import math |
4 | import os | ||
5 | import datetime | 4 | import datetime |
6 | import logging | 5 | import logging |
7 | import json | 6 | import json |
8 | from pathlib import Path | 7 | from pathlib import Path |
9 | 8 | ||
10 | import numpy as np | ||
11 | import torch | 9 | import torch |
12 | import torch.nn.functional as F | 10 | import torch.nn.functional as F |
13 | import torch.utils.checkpoint | 11 | import torch.utils.checkpoint |
@@ -299,7 +297,7 @@ def parse_args(): | |||
299 | parser.add_argument( | 297 | parser.add_argument( |
300 | "--sample_steps", | 298 | "--sample_steps", |
301 | type=int, | 299 | type=int, |
302 | default=20, | 300 | default=15, |
303 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 301 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
304 | ) | 302 | ) |
305 | parser.add_argument( | 303 | parser.add_argument( |
@@ -613,7 +611,7 @@ def main(): | |||
613 | ) | 611 | ) |
614 | 612 | ||
615 | # Freeze text_encoder and vae | 613 | # Freeze text_encoder and vae |
616 | freeze_params(vae.parameters()) | 614 | vae.requires_grad_(False) |
617 | 615 | ||
618 | if len(args.placeholder_token) != 0: | 616 | if len(args.placeholder_token) != 0: |
619 | print(f"Adding text embeddings: {args.placeholder_token}") | 617 | print(f"Adding text embeddings: {args.placeholder_token}") |
@@ -629,6 +627,10 @@ def main(): | |||
629 | 627 | ||
630 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 628 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
631 | 629 | ||
630 | print(f"Token ID mappings:") | ||
631 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): | ||
632 | print(f"- {token_id} {token}") | ||
633 | |||
632 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 634 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
633 | text_encoder.resize_token_embeddings(len(tokenizer)) | 635 | text_encoder.resize_token_embeddings(len(tokenizer)) |
634 | 636 | ||