From 633d890e4964e070be9b0a5b299c2f2e51d4b055 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 17 Oct 2022 12:27:53 +0200
Subject: Upstream updates; better handling of textual embedding

---
 textual_inversion.py | 41 ++++++++++++++++++++++++-----------------
 1 file changed, 24 insertions(+), 17 deletions(-)

diff --git a/textual_inversion.py b/textual_inversion.py
index 0d5a742..69d9c7f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -55,9 +55,9 @@ def parse_args():
         help="A CSV file containing the training data."
     )
     parser.add_argument(
-        "--placeholder_token",
+        "--instance_identifier",
         type=str,
-        default="<*>",
+        default=None,
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
@@ -66,6 +66,12 @@ def parse_args():
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
+    parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
     parser.add_argument(
         "--initializer_token",
         type=str,
@@ -333,6 +339,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
+        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -347,6 +354,7 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.instance_identifier = instance_identifier
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -413,7 +421,7 @@ class Checkpointer:
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.placeholder_token)
+                prompt = [prompt.format(self.instance_identifier)
                           for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
@@ -428,7 +436,7 @@ class Checkpointer:
                     eta=eta,
                     num_inference_steps=num_inference_steps,
                     output_type='pil'
-                )["sample"]
+                ).images
 
                 all_samples += samples
 
@@ -480,28 +488,26 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
+    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
+    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
     if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
-            " `placeholder_token` that is not already in the tokenizer."
-        )
-
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # Check if initializer_token is a single token or a sequence of tokens
-    if len(initializer_token_ids) > 1:
-        raise ValueError(
-            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+        print(f"Re-using existing token {args.placeholder_token}.")
+    else:
+        print(f"Training new token {args.placeholder_token}.")
 
-    initializer_token_ids = torch.tensor(initializer_token_ids)
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Load models and create wrapper for stable diffusion
@@ -602,7 +608,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_identifier=args.placeholder_token,
+        instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -730,6 +736,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
--
cgit v1.2.3-54-g00ecf
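
The token and identifier handling this commit introduces can be exercised outside the training script. Below is a minimal sketch, assuming the transformers CLIPTokenizer; the checkpoint name and the example argument values ("photo of {}", "sculpture") are illustrative stand-ins, not values taken from the patch.

# Minimal sketch (not part of the patch): replicates the new
# instance-identifier and placeholder-token handling in isolation.
import torch
from transformers import CLIPTokenizer

# Illustrative values standing in for the parsed CLI arguments.
placeholder_token = "<*>"            # --placeholder_token (default)
instance_identifier = "photo of {}"  # hypothetical --instance_identifier
initializer_token = "sculpture"      # hypothetical --initializer_token

# --instance_identifier is treated as a template: "{}" is replaced by the
# placeholder token, and prompts are built from the formatted result.
instance_identifier = instance_identifier.format(placeholder_token)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# The initializer token may tokenize to several ids; only the first
# embedding is kept, mirroring the [:1] slice in the patch.
initializer_token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
initializer_token_ids = torch.tensor(initializer_token_ids[:1])

# Re-using an existing placeholder token is no longer a hard error; the
# script now resumes training on the existing embedding instead of raising.
num_added_tokens = tokenizer.add_tokens(placeholder_token)
if num_added_tokens == 0:
    print(f"Re-using existing token {placeholder_token}.")
else:
    print(f"Training new token {placeholder_token}.")

placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
print(f"Placeholder token id: {placeholder_token_id}")

The remaining change, reading sampler output via the .images attribute instead of ["sample"], tracks the upstream diffusers API, where pipelines now return an output object rather than a dict.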