From f07104c79fe1e3574fd0fb11f8bd400c96de9def Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Oct 2022 20:42:40 +0200 Subject: Multi-vector TI was broken --- textual_dreambooth.py | 107 ++++++++++++++++++-------------------------------- 1 file changed, 38 insertions(+), 69 deletions(-) diff --git a/textual_dreambooth.py b/textual_dreambooth.py index a46953d..c07d98b 100644 --- a/textual_dreambooth.py +++ b/textual_dreambooth.py @@ -67,20 +67,6 @@ def parse_args(): default=None, help="A token to use as initializer word." ) - parser.add_argument( - "--num_vec_per_token", - type=int, - default=1, - help=( - "The number of vectors used to represent the placeholder token. The higher the number, the better the" - " result at the cost of editability. This can be fixed by prompt editing." - ), - ) - parser.add_argument( - "--initialize_rest_random", - action="store_true", - help="Initialize rest of the placeholder tokens with random." - ) parser.add_argument( "--use_class_images", action="store_true", @@ -324,40 +310,6 @@ def make_grid(images, rows, cols): return grid -def add_tokens_and_get_placeholder_token(args, token_ids, tokenizer, text_encoder): - assert args.num_vec_per_token >= len(token_ids) - placeholder_tokens = [f"{args.placeholder_token}_{i}" for i in range(args.num_vec_per_token)] - - for placeholder_token in placeholder_tokens: - num_added_tokens = tokenizer.add_tokens(placeholder_token) - if num_added_tokens == 0: - raise ValueError( - f"The tokenizer already contains the token {placeholder_token}. Please pass a different" - " `placeholder_token` that is not already in the tokenizer." - ) - - placeholder_token = " ".join(placeholder_tokens) - placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False) - - print(f"The placeholder tokens are {placeholder_token} while the ids are {placeholder_token_ids}") - - text_encoder.resize_token_embeddings(len(tokenizer)) - token_embeds = text_encoder.get_input_embeddings().weight.data - - if args.initialize_rest_random: - # The idea is that the placeholder tokens form adjectives as in x x x white dog. - for i, placeholder_token_id in enumerate(placeholder_token_ids): - if len(placeholder_token_ids) - i < len(token_ids): - token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]] - else: - token_embeds[placeholder_token_id] = torch.rand_like(token_embeds[placeholder_token_id]) - else: - for i, placeholder_token_id in enumerate(placeholder_token_ids): - token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]] - - return placeholder_token, placeholder_token_ids - - class Checkpointer: def __init__( self, @@ -367,7 +319,7 @@ class Checkpointer: unet, tokenizer, placeholder_token, - placeholder_token_ids, + placeholder_token_id, output_dir, sample_image_size, sample_batches, @@ -380,7 +332,7 @@ class Checkpointer: self.unet = unet self.tokenizer = tokenizer self.placeholder_token = placeholder_token - self.placeholder_token_ids = placeholder_token_ids + self.placeholder_token_id = placeholder_token_id self.output_dir = output_dir self.sample_image_size = sample_image_size self.seed = seed or torch.random.seed() @@ -398,10 +350,8 @@ class Checkpointer: unwrapped = self.accelerator.unwrap_model(text_encoder) # Save a checkpoint - learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_ids] - learned_embeds_dict = {} - for i, placeholder_token in enumerate(self.placeholder_token.split(" ")): - learned_embeds_dict[placeholder_token] = learned_embeds[i].detach().cpu() + learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] + learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) if path is not None: @@ -527,6 +477,24 @@ def main(): args.pretrained_model_name_or_path + '/tokenizer' ) + # Add the placeholder token in tokenizer + num_added_tokens = tokenizer.add_tokens(args.placeholder_token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + # Convert the initializer_token, placeholder_token to ids + initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(initializer_token_ids) > 1: + raise ValueError( + f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") + + initializer_token_ids = torch.tensor(initializer_token_ids) + placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path + '/text_encoder', @@ -544,16 +512,19 @@ def main(): slice_size = unet.config.attention_head_dim // 2 unet.set_attention_slice(slice_size) - token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) - # regardless of whether the number of token_ids is 1 or more, it'll set one and then keep repeating. - placeholder_token, placeholder_token_ids = add_tokens_and_get_placeholder_token( - args, token_ids, tokenizer, text_encoder) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) - # if args.resume_checkpoint is not None: - # token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ - # args.placeholder_token] - # else: - # token_embeds[placeholder_token_id] = initializer_token_embeddings + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + + initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) + + if args.resume_checkpoint is not None: + token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ + args.placeholder_token] + else: + token_embeds[placeholder_token_id] = initializer_token_embeddings # Freeze vae and unet freeze_params(vae.parameters()) @@ -627,7 +598,7 @@ def main(): data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, - instance_identifier=placeholder_token, + instance_identifier=args.placeholder_token, class_identifier=args.initializer_token if args.use_class_images else None, class_subdir="ti_cls", size=args.resolution, @@ -690,7 +661,7 @@ def main(): unet=unet, tokenizer=tokenizer, placeholder_token=args.placeholder_token, - placeholder_token_ids=placeholder_token_ids, + placeholder_token_id=placeholder_token_id, output_dir=basepath, sample_image_size=args.sample_image_size, sample_batch_size=args.sample_batch_size, @@ -823,10 +794,8 @@ def main(): else: grads = text_encoder.get_input_embeddings().weight.grad # Get the index for tokens that we want to zero the grads for - grad_mask = torch.arange(len(tokenizer)) != placeholder_token_ids[0] - for i in range(1, len(placeholder_token_ids)): - grad_mask = grad_mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]) - grads.data[grad_mask, :] = grads.data[grad_mask, :].fill_(0) + index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id + grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) optimizer.step() if not accelerator.optimizer_step_was_skipped: -- cgit v1.2.3-54-g00ecf