 textual_dreambooth.py | 107 ++++++++++++++-------------------------
 1 file changed, 38 insertions(+), 69 deletions(-)
diff --git a/textual_dreambooth.py b/textual_dreambooth.py
index a46953d..c07d98b 100644
--- a/textual_dreambooth.py
+++ b/textual_dreambooth.py
@@ -68,20 +68,6 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--num_vec_per_token",
-        type=int,
-        default=1,
-        help=(
-            "The number of vectors used to represent the placeholder token. The higher the number, the better the"
-            " result at the cost of editability. This can be fixed by prompt editing."
-        ),
-    )
-    parser.add_argument(
-        "--initialize_rest_random",
-        action="store_true",
-        help="Initialize rest of the placeholder tokens with random."
-    )
-    parser.add_argument(
         "--use_class_images",
         action="store_true",
         default=True,
@@ -324,40 +310,6 @@ def make_grid(images, rows, cols):
     return grid


-def add_tokens_and_get_placeholder_token(args, token_ids, tokenizer, text_encoder):
-    assert args.num_vec_per_token >= len(token_ids)
-    placeholder_tokens = [f"{args.placeholder_token}_{i}" for i in range(args.num_vec_per_token)]
-
-    for placeholder_token in placeholder_tokens:
-        num_added_tokens = tokenizer.add_tokens(placeholder_token)
-        if num_added_tokens == 0:
-            raise ValueError(
-                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
-                " `placeholder_token` that is not already in the tokenizer."
-            )
-
-    placeholder_token = " ".join(placeholder_tokens)
-    placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
-
-    print(f"The placeholder tokens are {placeholder_token} while the ids are {placeholder_token_ids}")
-
-    text_encoder.resize_token_embeddings(len(tokenizer))
-    token_embeds = text_encoder.get_input_embeddings().weight.data
-
-    if args.initialize_rest_random:
-        # The idea is that the placeholder tokens form adjectives as in x x x white dog.
-        for i, placeholder_token_id in enumerate(placeholder_token_ids):
-            if len(placeholder_token_ids) - i < len(token_ids):
-                token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]]
-            else:
-                token_embeds[placeholder_token_id] = torch.rand_like(token_embeds[placeholder_token_id])
-    else:
-        for i, placeholder_token_id in enumerate(placeholder_token_ids):
-            token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]]
-
-    return placeholder_token, placeholder_token_ids
-
-
 class Checkpointer:
     def __init__(
         self,
@@ -367,7 +319,7 @@ class Checkpointer:
         unet,
         tokenizer,
         placeholder_token,
-        placeholder_token_ids,
+        placeholder_token_id,
         output_dir,
         sample_image_size,
         sample_batches,
@@ -380,7 +332,7 @@
         self.unet = unet
         self.tokenizer = tokenizer
         self.placeholder_token = placeholder_token
-        self.placeholder_token_ids = placeholder_token_ids
+        self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
@@ -398,10 +350,8 @@
        unwrapped = self.accelerator.unwrap_model(text_encoder)

        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_ids]
-        learned_embeds_dict = {}
-        for i, placeholder_token in enumerate(self.placeholder_token.split(" ")):
-            learned_embeds_dict[placeholder_token] = learned_embeds[i].detach().cpu()
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}

        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
        if path is not None:
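
The checkpoint is now a one-entry dict mapping the placeholder token to its learned embedding, which is the same shape of object the resume path later feeds to torch.load. A rough standalone sketch of that round trip; the helper names and the path argument are mine, not part of the patch:

import torch

# Hypothetical helper: write the single learned vector the way save() above does,
# i.e. a dict mapping the placeholder token string to its embedding tensor.
def save_learned_embed(text_encoder, placeholder_token, placeholder_token_id, path):
    learned = text_encoder.get_input_embeddings().weight[placeholder_token_id]
    torch.save({placeholder_token: learned.detach().cpu()}, path)

# Hypothetical helper: read the file back and copy the vector into an
# already-resized embedding matrix.
def load_learned_embed(text_encoder, tokenizer, path):
    state = torch.load(path, map_location="cpu")
    (token, embed), = state.items()  # exactly one entry per file
    token_id = tokenizer.convert_tokens_to_ids(token)
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_id] = embed
    return token, token_id
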
@@ -527,6 +477,24 @@ def main():
         args.pretrained_model_name_or_path + '/tokenizer'
     )

+    # Add the placeholder token in tokenizer
+    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+    if num_added_tokens == 0:
+        raise ValueError(
+            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+            " `placeholder_token` that is not already in the tokenizer."
+        )
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+    # Check if initializer_token is a single token or a sequence of tokens
+    if len(initializer_token_ids) > 1:
+        raise ValueError(
+            f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
+
+    initializer_token_ids = torch.tensor(initializer_token_ids)
+    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(
         args.pretrained_model_name_or_path + '/text_encoder',
@@ -544,16 +512,19 @@ def main():
     slice_size = unet.config.attention_head_dim // 2
     unet.set_attention_slice(slice_size)

-    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    # regardless of whether the number of token_ids is 1 or more, it'll set one and then keep repeating.
-    placeholder_token, placeholder_token_ids = add_tokens_and_get_placeholder_token(
-        args, token_ids, tokenizer, text_encoder)
+    # Resize the token embeddings as we are adding new special tokens to the tokenizer
+    text_encoder.resize_token_embeddings(len(tokenizer))

-    # if args.resume_checkpoint is not None:
-    #     token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
-    #         args.placeholder_token]
-    # else:
-    #     token_embeds[placeholder_token_id] = initializer_token_embeddings
+    # Initialise the newly added placeholder token with the embeddings of the initializer token
+    token_embeds = text_encoder.get_input_embeddings().weight.data
+
+    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+
+    if args.resume_checkpoint is not None:
+        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
+            args.placeholder_token]
+    else:
+        token_embeds[placeholder_token_id] = initializer_token_embeddings

     # Freeze vae and unet
     freeze_params(vae.parameters())
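
Together, the two hunks above replace the multi-vector helper with the standard single-vector textual-inversion setup: register one placeholder token, resize the embedding matrix, then seed the new row from the initializer token (or from a resumed checkpoint). A self-contained sketch of that flow against the transformers CLIP classes; the model path and token strings are placeholders of mine, not values from the patch:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_path = "path/to/stable-diffusion"                     # placeholder path
placeholder_token, initializer_token = "<my-token>", "dog"  # example values

tokenizer = CLIPTokenizer.from_pretrained(model_path + "/tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_path + "/text_encoder")

# Register the placeholder token; refuse to silently reuse an existing one.
if tokenizer.add_tokens(placeholder_token) == 0:
    raise ValueError(f"{placeholder_token} already exists in the tokenizer")

# The initializer must map to exactly one token id so a single vector can be copied.
initializer_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
assert len(initializer_ids) == 1, "initializer_token must be a single token"
placeholder_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# Grow the embedding matrix and copy the initializer's vector into the new row.
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_id] = token_embeds[initializer_ids[0]].clone()

Indexing the embedding matrix by initializer_ids[0] keeps the copied vector one-dimensional, so the row assignment has an unambiguous shape.
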
@@ -627,7 +598,7 @@ def main():
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
-        instance_identifier=placeholder_token,
+        instance_identifier=args.placeholder_token,
        class_identifier=args.initializer_token if args.use_class_images else None,
        class_subdir="ti_cls",
        size=args.resolution,
@@ -690,7 +661,7 @@ def main():
        unet=unet,
        tokenizer=tokenizer,
        placeholder_token=args.placeholder_token,
-        placeholder_token_ids=placeholder_token_ids,
+        placeholder_token_id=placeholder_token_id,
        output_dir=basepath,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
@@ -823,10 +794,8 @@ def main():
                else:
                    grads = text_encoder.get_input_embeddings().weight.grad
                # Get the index for tokens that we want to zero the grads for
-                grad_mask = torch.arange(len(tokenizer)) != placeholder_token_ids[0]
-                for i in range(1, len(placeholder_token_ids)):
-                    grad_mask = grad_mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i])
-                grads.data[grad_mask, :] = grads.data[grad_mask, :].fill_(0)
+                index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
+                grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                optimizer.step()
                if not accelerator.optimizer_step_was_skipped:
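
With a single learned vector, the gradient mask in the last hunk collapses to one comparison against placeholder_token_id: every other row of the embedding matrix has its gradient zeroed before optimizer.step(), so only the placeholder embedding is trained. A minimal, self-contained illustration of that masking; the vocabulary size and token ids below are made up:

import torch

vocab_size, embed_dim, placeholder_token_id = 49409, 768, 49408  # illustrative sizes

embeddings = torch.nn.Embedding(vocab_size, embed_dim)
loss = embeddings(torch.tensor([placeholder_token_id, 42])).sum()
loss.backward()

# Boolean mask over the vocabulary: True for every row that should NOT be updated.
index_grads_to_zero = torch.arange(vocab_size) != placeholder_token_id
embeddings.weight.grad.data[index_grads_to_zero, :] = 0.0

# Only the placeholder row keeps a non-zero gradient, so an optimizer step
# leaves the rest of the vocabulary untouched.
assert embeddings.weight.grad[42].abs().sum() == 0
assert embeddings.weight.grad[placeholder_token_id].abs().sum() != 0

The same effect could be achieved by optimizing only that one row, but zeroing the gradients keeps the optimizer wiring unchanged.
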