summary | refs | log | tree | commit | diff | stats
path: root/textual_inversion.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2022-10-04 19:22:22 +0200
committer: Volpeon <git@volpeon.ink> 2022-10-04 19:22:22 +0200
commit: 300deaa789a0321f32d5e7f04d9860eaa258110e (patch)
tree: 892e89753e5c4d86d787131595751bc03c610be8 /textual_inversion.py
parent: Default sample steps 30 -> 40 (diff)
download: textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.gz
textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.bz2
textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.zip
Add Textual Inversion with class dataset (a la Dreambooth)
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- textual_inversion.py | 13
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index d842288..7919ebd 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -230,7 +230,7 @@ def parse_args():
230 parser.add_argument( 230 parser.add_argument(
231 "--sample_steps", 231 "--sample_steps",
232 type=int, 232 type=int,
233 default=40, 233 default=30,
234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
235 ) 235 )
236 parser.add_argument( 236 parser.add_argument(
@@ -329,7 +329,7 @@ class Checkpointer:
329 self.placeholder_token_id = placeholder_token_id 329 self.placeholder_token_id = placeholder_token_id
330 self.output_dir = output_dir 330 self.output_dir = output_dir
331 self.sample_image_size = sample_image_size 331 self.sample_image_size = sample_image_size
332 self.seed = seed 332 self.seed = seed or torch.random.seed()
333 self.sample_batches = sample_batches 333 self.sample_batches = sample_batches
334 self.sample_batch_size = sample_batch_size 334 self.sample_batch_size = sample_batch_size
335 335
@@ -481,9 +481,9 @@ def main():
481 # Convert the initializer_token, placeholder_token to ids 481 # Convert the initializer_token, placeholder_token to ids
482 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) 482 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
483 # Check if initializer_token is a single token or a sequence of tokens 483 # Check if initializer_token is a single token or a sequence of tokens
484 if args.vectors_per_token % len(initializer_token_ids) != 0: 484 if len(initializer_token_ids) > 1:
485 raise ValueError( 485 raise ValueError(
486 f"vectors_per_token ({args.vectors_per_token}) must be divisible by initializer token ({len(initializer_token_ids)}).") 486 f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
487 487
488 initializer_token_ids = torch.tensor(initializer_token_ids) 488 initializer_token_ids = torch.tensor(initializer_token_ids)
489 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) 489 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
@@ -590,7 +590,7 @@ def main():
590 sample_image_size=args.sample_image_size, 590 sample_image_size=args.sample_image_size,
591 sample_batch_size=args.sample_batch_size, 591 sample_batch_size=args.sample_batch_size,
592 sample_batches=args.sample_batches, 592 sample_batches=args.sample_batches,
593 seed=args.seed or torch.random.seed() 593 seed=args.seed
594 ) 594 )
595 595
596 # Scheduler and math around the number of training steps. 596 # Scheduler and math around the number of training steps.
@@ -620,8 +620,7 @@ def main():
620 unet.eval() 620 unet.eval()
621 621
622 # We need to recalculate our total training steps as the size of the training dataloader may have changed. 622 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
623 num_update_steps_per_epoch = math.ceil( 623 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
624 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
625 if overrode_max_train_steps: 624 if overrode_max_train_steps:
626 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 625 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
627 626