path: root/train_ti.py
author     Volpeon <git@volpeon.ink>  2022-12-31 12:58:54 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-31 12:58:54 +0100
commit     6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch)
tree       52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /train_ti.py
parent     Misc improvements (diff)
download   textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.gz
           textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.bz2
           textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.zip
Added multi-vector embeddings
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  88
1 file changed, 44 insertions(+), 44 deletions(-)
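
Note: the hunks below replace the stock CLIPTokenizer and the single-ID placeholder handling with a MultiCLIPTokenizer plus managed embeddings, so each placeholder token can own several trainable vectors (--num_vectors takes one count per --placeholder_token, e.g. --placeholder_token "<tok-a>" "<tok-b>" --num_vectors 4 2). The snippet below is a minimal sketch of the call pattern this commit introduces, pieced together from the hunks themselves; it assumes the repo-local modules and an illustrative model path, and is not a documented API.

# Illustrative sketch only: names and call order are taken from the diff below;
# exact signatures live in models/clip/tokenizer.py and models/clip/embeddings.py.
from transformers import CLIPTextModel
from models.clip.embeddings import patch_managed_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer

model = "runwayml/stable-diffusion-v1-5"  # hypothetical --pretrained_model_name_or_path
tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder')

embeddings = patch_managed_embeddings(text_encoder)

# e.g. --placeholder_token "<tok>" --initializer_token "cat" --num_vectors 4
init_ids = tokenizer.encode("cat", add_special_tokens=False)
new_tokens = tokenizer.add_multi_tokens(["<tok>"], [4])

for new_token in new_tokens:
    embeddings.add_embed(new_token.placeholder_id)
    embeddings.add_embed(new_token.multi_ids, init_ids)
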
diff --git a/train_ti.py b/train_ti.py
index 088c1a6..69d15ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -16,17 +16,18 @@ from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler,
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel
 from slugify import slugify

-from common import load_text_embeddings, load_text_embedding, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer

 logger = get_logger(__name__)

@@ -81,6 +82,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
@@ -360,8 +367,17 @@ def parse_args():
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]

+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("You must specify --placeholder_token")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")

     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -386,8 +402,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        placeholder_token,
-        placeholder_token_id,
+        new_tokens,
         output_dir: Path,
         sample_image_size,
         sample_batches,
@@ -397,8 +412,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            placeholder_token=placeholder_token,
-            placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
             seed=seed or torch.random.seed(),
             sample_batches=sample_batches,
@@ -412,6 +425,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.new_tokens = new_tokens

     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -422,13 +436,11 @@ class Checkpointer(CheckpointerBase):

         text_encoder = self.accelerator.unwrap_model(self.text_encoder)

-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        for new_token in self.new_tokens:
+            text_encoder.text_model.embeddings.save_embed(
+                new_token.multi_ids,
+                f"{slugify(new_token.token)}_{step}_{postfix}.bin"
+            )

         del text_encoder
         del learned_embeds
@@ -487,9 +499,9 @@ def main():

     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -507,45 +519,33 @@ def main():
     unet.enable_gradient_checkpointing()
     text_encoder.gradient_checkpointing_enable()

+    embeddings = patch_managed_embeddings(text_encoder)
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")

     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
         for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+    ]

-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)

-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+        embeddings.add_embed(new_token.placeholder_id)
+        embeddings.add_embed(new_token.multi_ids, init_ids)

-    if args.resume_from is not None:
-        resumepath = Path(args.resume_from).joinpath("checkpoints")
-
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
-
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+    print(f"Added {len(new_tokens)} new tokens.")

     vae.requires_grad_(False)
     unet.requires_grad_(False)

-    patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
@@ -575,7 +575,7 @@ def main():

     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.trainable_embedding.parameters(), # only optimize the embeddings
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(), # only optimize the embeddings
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -816,6 +816,7 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
+    config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
     if config["collection"] is not None:
         config["collection"] = " ".join(config["collection"])
     if config["exclude_collections"] is not None:
@@ -852,8 +853,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
+        new_tokens=new_tokens,
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
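
For context on what "multi-vector" means here: instead of learning a single embedding row per placeholder token, the placeholder expands into several token IDs whose embedding rows are all initialised from the initializer token and trained together, giving the learned concept more capacity in the prompt. A toy, repo-independent illustration of that idea (not the project's implementation):

import torch

# Toy illustration, not the project's code: one placeholder owning three
# trainable vectors, each initialised from the initializer token's embedding.
embedding_dim = 768                    # CLIP ViT-L/14 text width
num_vectors = 3
init_vec = torch.randn(embedding_dim)  # stand-in for the initializer embedding

# One trainable row per vector; encoding the placeholder consumes three prompt slots.
multi_vectors = init_vec.repeat(num_vectors, 1).requires_grad_(True)

optimizer = torch.optim.AdamW([multi_vectors], lr=1e-2)  # only these rows get gradients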