author     Volpeon <git@volpeon.ink>  2022-12-10 08:43:34 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-10 08:43:34 +0100
commit     64c79cc3e7fad49131f90fbb0648b6d5587563e5 (patch)
tree       372bb09a8c952bd28a8da069659da26ce2c99894 /dreambooth.py
parent     Fix sample steps (diff)
Various updates; shuffle prompt content during training
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index ec9531e..0044c1e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,13 +1,11 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -299,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
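The default for --sample_steps drops from 20 to 15, trading a little detail in the preview samples for shorter sampling time. As a rough sketch of how such a flag is typically consumed downstream (the pipeline wiring and model name below are assumptions for illustration, not code from dreambooth.py):

# Hedged sketch: a typical way a sample_steps flag feeds a diffusers pipeline.
# Model name and pipeline setup are assumptions, not taken from this repo.
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
image = pipeline(
    "a sample prompt",
    num_inference_steps=15,  # fewer steps: faster previews, slightly less detail
).images[0]
image.save("sample.png")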
@@ -613,7 +611,7 @@ def main():
     )
 
     # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    vae.requires_grad_(False)
 
     if len(args.placeholder_token) != 0:
         print(f"Adding text embeddings: {args.placeholder_token}")
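The hand-rolled freeze_params helper gives way to torch's built-in Module.requires_grad_. A minimal equivalence sketch (the freeze_params body is assumed from the usual pattern; requires_grad_ is confirmed torch API):

import torch

def freeze_params(params):
    # presumed shape of the old helper: flip the flag per parameter
    for param in params:
        param.requires_grad = False

vae = torch.nn.Linear(4, 4)      # stand-in module for illustration
freeze_params(vae.parameters())  # old style
vae.requires_grad_(False)        # new style: same effect via the torch built-in
assert all(not p.requires_grad for p in vae.parameters())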
@@ -629,6 +627,10 @@
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    print(f"Token ID mappings:")
+    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+        print(f"- {token_id} {token}")
+
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
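The added loop logs which ID each placeholder token received after registration. transformers tokenizers return a list of IDs when convert_tokens_to_ids is given a list, so it zips cleanly with args.placeholder_token. A self-contained sketch of the same printout (the placeholder token strings are invented for illustration):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
placeholder_tokens = ["<token1>", "<token2>"]  # invented examples
tokenizer.add_tokens(placeholder_tokens)

placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_tokens)
print("Token ID mappings:")
for token_id, token in zip(placeholder_token_id, placeholder_tokens):
    print(f"- {token_id} {token}")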