| author | Volpeon <git@volpeon.ink> | 2022-12-10 08:43:34 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-10 08:43:34 +0100 |
| commit | 64c79cc3e7fad49131f90fbb0648b6d5587563e5 (patch) | |
| tree | 372bb09a8c952bd28a8da069659da26ce2c99894 /dreambooth.py | |
| parent | Fix sample steps (diff) | |
Various updates; shuffle prompt content during training
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 10 |
1 file changed, 6 insertions, 4 deletions
diff --git a/dreambooth.py b/dreambooth.py
index ec9531e..0044c1e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,13 +1,11 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -299,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -613,7 +611,7 @@ def main():
     )
 
     # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    vae.requires_grad_(False)
 
     if len(args.placeholder_token) != 0:
         print(f"Adding text embeddings: {args.placeholder_token}")
@@ -629,6 +627,10 @@
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    print(f"Token ID mappings:")
+    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+        print(f"- {token_id} {token}")
+
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
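A note on the freezing change: `vae.requires_grad_(False)` is the module-level equivalent of walking `vae.parameters()` and disabling gradients one parameter at a time, which is what the removed `freeze_params(...)` call did. A minimal sketch of that equivalence, where `freeze_params` is a stand-in for the helper defined elsewhere in the script and the small module stands in for the real VAE:

```python
import torch


def freeze_params(params):
    # Old style: walk the parameter iterator and switch off gradients.
    for param in params:
        param.requires_grad = False


# Stand-in module; the script freezes an AutoencoderKL instance instead.
vae = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

# New style: Module.requires_grad_(False) applies the same flag to every
# parameter of the module in one call.
vae.requires_grad_(False)

assert all(not p.requires_grad for p in vae.parameters())
```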
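The newly added logging loop pairs each placeholder token with the ID the tokenizer assigned to it. A rough, self-contained sketch of how that mapping arises with the Hugging Face tokenizer API; the model name and placeholder tokens below are illustrative rather than taken from the script's CLI arguments:

```python
from transformers import CLIPTextModel, CLIPTokenizer

# Illustrative values; dreambooth.py reads these from its command-line arguments.
pretrained = "openai/clip-vit-large-patch14"
placeholder_token = ["<concept-1>", "<concept-2>"]

tokenizer = CLIPTokenizer.from_pretrained(pretrained)
text_encoder = CLIPTextModel.from_pretrained(pretrained)

# Register the placeholder tokens, then look up the IDs they were assigned.
tokenizer.add_tokens(placeholder_token)
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

print(f"Token ID mappings:")
for (token_id, token) in zip(placeholder_token_id, placeholder_token):
    print(f"- {token_id} {token}")

# New tokens also need matching rows in the text encoder's embedding table.
text_encoder.resize_token_embeddings(len(tokenizer))
```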
