diff options
-rw-r--r-- | textual_dreambooth.py | 107 |
1 files changed, 38 insertions, 69 deletions
diff --git a/textual_dreambooth.py b/textual_dreambooth.py index a46953d..c07d98b 100644 --- a/textual_dreambooth.py +++ b/textual_dreambooth.py | |||
@@ -68,20 +68,6 @@ def parse_args(): | |||
68 | help="A token to use as initializer word." | 68 | help="A token to use as initializer word." |
69 | ) | 69 | ) |
70 | parser.add_argument( | 70 | parser.add_argument( |
71 | "--num_vec_per_token", | ||
72 | type=int, | ||
73 | default=1, | ||
74 | help=( | ||
75 | "The number of vectors used to represent the placeholder token. The higher the number, the better the" | ||
76 | " result at the cost of editability. This can be fixed by prompt editing." | ||
77 | ), | ||
78 | ) | ||
79 | parser.add_argument( | ||
80 | "--initialize_rest_random", | ||
81 | action="store_true", | ||
82 | help="Initialize rest of the placeholder tokens with random." | ||
83 | ) | ||
84 | parser.add_argument( | ||
85 | "--use_class_images", | 71 | "--use_class_images", |
86 | action="store_true", | 72 | action="store_true", |
87 | default=True, | 73 | default=True, |
@@ -324,40 +310,6 @@ def make_grid(images, rows, cols): | |||
324 | return grid | 310 | return grid |
325 | 311 | ||
326 | 312 | ||
327 | def add_tokens_and_get_placeholder_token(args, token_ids, tokenizer, text_encoder): | ||
328 | assert args.num_vec_per_token >= len(token_ids) | ||
329 | placeholder_tokens = [f"{args.placeholder_token}_{i}" for i in range(args.num_vec_per_token)] | ||
330 | |||
331 | for placeholder_token in placeholder_tokens: | ||
332 | num_added_tokens = tokenizer.add_tokens(placeholder_token) | ||
333 | if num_added_tokens == 0: | ||
334 | raise ValueError( | ||
335 | f"The tokenizer already contains the token {placeholder_token}. Please pass a different" | ||
336 | " `placeholder_token` that is not already in the tokenizer." | ||
337 | ) | ||
338 | |||
339 | placeholder_token = " ".join(placeholder_tokens) | ||
340 | placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False) | ||
341 | |||
342 | print(f"The placeholder tokens are {placeholder_token} while the ids are {placeholder_token_ids}") | ||
343 | |||
344 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
345 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
346 | |||
347 | if args.initialize_rest_random: | ||
348 | # The idea is that the placeholder tokens form adjectives as in x x x white dog. | ||
349 | for i, placeholder_token_id in enumerate(placeholder_token_ids): | ||
350 | if len(placeholder_token_ids) - i < len(token_ids): | ||
351 | token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]] | ||
352 | else: | ||
353 | token_embeds[placeholder_token_id] = torch.rand_like(token_embeds[placeholder_token_id]) | ||
354 | else: | ||
355 | for i, placeholder_token_id in enumerate(placeholder_token_ids): | ||
356 | token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]] | ||
357 | |||
358 | return placeholder_token, placeholder_token_ids | ||
359 | |||
360 | |||
361 | class Checkpointer: | 313 | class Checkpointer: |
362 | def __init__( | 314 | def __init__( |
363 | self, | 315 | self, |
@@ -367,7 +319,7 @@ class Checkpointer: | |||
367 | unet, | 319 | unet, |
368 | tokenizer, | 320 | tokenizer, |
369 | placeholder_token, | 321 | placeholder_token, |
370 | placeholder_token_ids, | 322 | placeholder_token_id, |
371 | output_dir, | 323 | output_dir, |
372 | sample_image_size, | 324 | sample_image_size, |
373 | sample_batches, | 325 | sample_batches, |
@@ -380,7 +332,7 @@ class Checkpointer: | |||
380 | self.unet = unet | 332 | self.unet = unet |
381 | self.tokenizer = tokenizer | 333 | self.tokenizer = tokenizer |
382 | self.placeholder_token = placeholder_token | 334 | self.placeholder_token = placeholder_token |
383 | self.placeholder_token_ids = placeholder_token_ids | 335 | self.placeholder_token_id = placeholder_token_id |
384 | self.output_dir = output_dir | 336 | self.output_dir = output_dir |
385 | self.sample_image_size = sample_image_size | 337 | self.sample_image_size = sample_image_size |
386 | self.seed = seed or torch.random.seed() | 338 | self.seed = seed or torch.random.seed() |
@@ -398,10 +350,8 @@ class Checkpointer: | |||
398 | unwrapped = self.accelerator.unwrap_model(text_encoder) | 350 | unwrapped = self.accelerator.unwrap_model(text_encoder) |
399 | 351 | ||
400 | # Save a checkpoint | 352 | # Save a checkpoint |
401 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_ids] | 353 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] |
402 | learned_embeds_dict = {} | 354 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} |
403 | for i, placeholder_token in enumerate(self.placeholder_token.split(" ")): | ||
404 | learned_embeds_dict[placeholder_token] = learned_embeds[i].detach().cpu() | ||
405 | 355 | ||
406 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | 356 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) |
407 | if path is not None: | 357 | if path is not None: |
@@ -527,6 +477,24 @@ def main(): | |||
527 | args.pretrained_model_name_or_path + '/tokenizer' | 477 | args.pretrained_model_name_or_path + '/tokenizer' |
528 | ) | 478 | ) |
529 | 479 | ||
480 | # Add the placeholder token in tokenizer | ||
481 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | ||
482 | if num_added_tokens == 0: | ||
483 | raise ValueError( | ||
484 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" | ||
485 | " `placeholder_token` that is not already in the tokenizer." | ||
486 | ) | ||
487 | |||
488 | # Convert the initializer_token, placeholder_token to ids | ||
489 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
490 | # Check if initializer_token is a single token or a sequence of tokens | ||
491 | if len(initializer_token_ids) > 1: | ||
492 | raise ValueError( | ||
493 | f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.") | ||
494 | |||
495 | initializer_token_ids = torch.tensor(initializer_token_ids) | ||
496 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | ||
497 | |||
530 | # Load models and create wrapper for stable diffusion | 498 | # Load models and create wrapper for stable diffusion |
531 | text_encoder = CLIPTextModel.from_pretrained( | 499 | text_encoder = CLIPTextModel.from_pretrained( |
532 | args.pretrained_model_name_or_path + '/text_encoder', | 500 | args.pretrained_model_name_or_path + '/text_encoder', |
@@ -544,16 +512,19 @@ def main(): | |||
544 | slice_size = unet.config.attention_head_dim // 2 | 512 | slice_size = unet.config.attention_head_dim // 2 |
545 | unet.set_attention_slice(slice_size) | 513 | unet.set_attention_slice(slice_size) |
546 | 514 | ||
547 | token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | 515 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
548 | # regardless of whether the number of token_ids is 1 or more, it'll set one and then keep repeating. | 516 | text_encoder.resize_token_embeddings(len(tokenizer)) |
549 | placeholder_token, placeholder_token_ids = add_tokens_and_get_placeholder_token( | ||
550 | args, token_ids, tokenizer, text_encoder) | ||
551 | 517 | ||
552 | # if args.resume_checkpoint is not None: | 518 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
553 | # token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ | 519 | token_embeds = text_encoder.get_input_embeddings().weight.data |
554 | # args.placeholder_token] | 520 | |
555 | # else: | 521 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
556 | # token_embeds[placeholder_token_id] = initializer_token_embeddings | 522 | |
523 | if args.resume_checkpoint is not None: | ||
524 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ | ||
525 | args.placeholder_token] | ||
526 | else: | ||
527 | token_embeds[placeholder_token_id] = initializer_token_embeddings | ||
557 | 528 | ||
558 | # Freeze vae and unet | 529 | # Freeze vae and unet |
559 | freeze_params(vae.parameters()) | 530 | freeze_params(vae.parameters()) |
@@ -627,7 +598,7 @@ def main(): | |||
627 | data_file=args.train_data_file, | 598 | data_file=args.train_data_file, |
628 | batch_size=args.train_batch_size, | 599 | batch_size=args.train_batch_size, |
629 | tokenizer=tokenizer, | 600 | tokenizer=tokenizer, |
630 | instance_identifier=placeholder_token, | 601 | instance_identifier=args.placeholder_token, |
631 | class_identifier=args.initializer_token if args.use_class_images else None, | 602 | class_identifier=args.initializer_token if args.use_class_images else None, |
632 | class_subdir="ti_cls", | 603 | class_subdir="ti_cls", |
633 | size=args.resolution, | 604 | size=args.resolution, |
@@ -690,7 +661,7 @@ def main(): | |||
690 | unet=unet, | 661 | unet=unet, |
691 | tokenizer=tokenizer, | 662 | tokenizer=tokenizer, |
692 | placeholder_token=args.placeholder_token, | 663 | placeholder_token=args.placeholder_token, |
693 | placeholder_token_ids=placeholder_token_ids, | 664 | placeholder_token_id=placeholder_token_id, |
694 | output_dir=basepath, | 665 | output_dir=basepath, |
695 | sample_image_size=args.sample_image_size, | 666 | sample_image_size=args.sample_image_size, |
696 | sample_batch_size=args.sample_batch_size, | 667 | sample_batch_size=args.sample_batch_size, |
@@ -823,10 +794,8 @@ def main(): | |||
823 | else: | 794 | else: |
824 | grads = text_encoder.get_input_embeddings().weight.grad | 795 | grads = text_encoder.get_input_embeddings().weight.grad |
825 | # Get the index for tokens that we want to zero the grads for | 796 | # Get the index for tokens that we want to zero the grads for |
826 | grad_mask = torch.arange(len(tokenizer)) != placeholder_token_ids[0] | 797 | index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id |
827 | for i in range(1, len(placeholder_token_ids)): | 798 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) |
828 | grad_mask = grad_mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]) | ||
829 | grads.data[grad_mask, :] = grads.data[grad_mask, :].fill_(0) | ||
830 | 799 | ||
831 | optimizer.step() | 800 | optimizer.step() |
832 | if not accelerator.optimizer_step_was_skipped: | 801 | if not accelerator.optimizer_step_was_skipped: |