From 300deaa789a0321f32d5e7f04d9860eaa258110e Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 4 Oct 2022 19:22:22 +0200
Subject: Add Textual Inversion with class dataset (a la Dreambooth)

---
 textual_inversion.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

(limited to 'textual_inversion.py')

diff --git a/textual_inversion.py b/textual_inversion.py
index d842288..7919ebd 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -230,7 +230,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=40,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -329,7 +329,7 @@ class Checkpointer:
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed
+        self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
@@ -481,9 +481,9 @@ def main():
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
     # Check if initializer_token is a single token or a sequence of tokens
-    if args.vectors_per_token % len(initializer_token_ids) != 0:
+    if len(initializer_token_ids) > 1:
         raise ValueError(
-            f"vectors_per_token ({args.vectors_per_token}) must be divisible by initializer token ({len(initializer_token_ids)}).")
+            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
 
     initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
@@ -590,7 +590,7 @@ def main():
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
-        seed=args.seed or torch.random.seed()
+        seed=args.seed
     )
 
     # Scheduler and math around the number of training steps.
@@ -620,8 +620,7 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(
-        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-- 
cgit v1.2.3-54-g00ecf
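
Note on the behavioural changes in the hunks above: the initializer token must now encode to exactly one token id (the previous divisibility check against --vectors_per_token is gone), and the random-seed fallback moves from the Checkpointer call site into Checkpointer.__init__. A minimal standalone sketch of both checks follows; the tokenizer model id and the surrounding setup are illustrative assumptions, not part of the patch:

    import torch
    from transformers import CLIPTokenizer

    # Illustrative tokenizer; the patch itself does not pin a model id.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # New check: a multi-token initializer is rejected outright instead of
    # being tested for divisibility against --vectors_per_token.
    initializer_token_ids = tokenizer.encode("sculpture", add_special_tokens=False)
    if len(initializer_token_ids) > 1:
        raise ValueError(
            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")

    # Seed handling after the patch: the fallback now lives in Checkpointer.__init__.
    seed = None                           # e.g. args.seed was not given
    seed = seed or torch.random.seed()    # draws and returns a fresh 64-bit seed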