diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 13 |
1 file changed, 6 insertions, 7 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index d842288..7919ebd 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -230,7 +230,7 @@ def parse_args(): | |||
230 | parser.add_argument( | 230 | parser.add_argument( |
231 | "--sample_steps", | 231 | "--sample_steps", |
232 | type=int, | 232 | type=int, |
233 | default=40, | 233 | default=30, |
234 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 234 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
235 | ) | 235 | ) |
236 | parser.add_argument( | 236 | parser.add_argument( |
@@ -329,7 +329,7 @@ class Checkpointer: | |||
329 | self.placeholder_token_id = placeholder_token_id | 329 | self.placeholder_token_id = placeholder_token_id |
330 | self.output_dir = output_dir | 330 | self.output_dir = output_dir |
331 | self.sample_image_size = sample_image_size | 331 | self.sample_image_size = sample_image_size |
332 | self.seed = seed | 332 | self.seed = seed or torch.random.seed() |
333 | self.sample_batches = sample_batches | 333 | self.sample_batches = sample_batches |
334 | self.sample_batch_size = sample_batch_size | 334 | self.sample_batch_size = sample_batch_size |
335 | 335 | ||
@@ -481,9 +481,9 @@ def main(): | |||
481 | # Convert the initializer_token, placeholder_token to ids | 481 | # Convert the initializer_token, placeholder_token to ids |
482 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | 482 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) |
483 | # Check if initializer_token is a single token or a sequence of tokens | 483 | # Check if initializer_token is a single token or a sequence of tokens |
484 | if args.vectors_per_token % len(initializer_token_ids) != 0: | 484 | if len(initializer_token_ids) > 1: |
485 | raise ValueError( | 485 | raise ValueError( |
486 | f"vectors_per_token ({args.vectors_per_token}) must be divisible by initializer token ({len(initializer_token_ids)}).") | 486 | f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") |
487 | 487 | ||
488 | initializer_token_ids = torch.tensor(initializer_token_ids) | 488 | initializer_token_ids = torch.tensor(initializer_token_ids) |
489 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 489 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
@@ -590,7 +590,7 @@ def main(): | |||
590 | sample_image_size=args.sample_image_size, | 590 | sample_image_size=args.sample_image_size, |
591 | sample_batch_size=args.sample_batch_size, | 591 | sample_batch_size=args.sample_batch_size, |
592 | sample_batches=args.sample_batches, | 592 | sample_batches=args.sample_batches, |
593 | seed=args.seed or torch.random.seed() | 593 | seed=args.seed |
594 | ) | 594 | ) |
595 | 595 | ||
596 | # Scheduler and math around the number of training steps. | 596 | # Scheduler and math around the number of training steps. |
@@ -620,8 +620,7 @@ def main(): | |||
620 | unet.eval() | 620 | unet.eval() |
621 | 621 | ||
622 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | 622 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. |
623 | num_update_steps_per_epoch = math.ceil( | 623 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
624 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | ||
625 | if overrode_max_train_steps: | 624 | if overrode_max_train_steps: |
626 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 625 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
627 | 626 | ||