- author: Volpeon <git@volpeon.ink>, 2022-12-31 12:58:54 +0100
- committer: Volpeon <git@volpeon.ink>, 2022-12-31 12:58:54 +0100
- commit: 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 (patch)
- tree: 52f10e5b7c8b1849fcd5c1210ca1cae21e2ac49e /train_ti.py
- parent: Misc improvements (diff)
- download: textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.gz, textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.tar.bz2, textual-inversion-diff-6b58e9de249e872bd2d83e5916e6c633f52cfbb8.zip
Added multi-vector embeddings
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  88
1 file changed, 44 insertions(+), 44 deletions(-)
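The gist of the change: a placeholder token is no longer tied to a single embedding vector. It can now expand to several token ids, each backed by its own trainable vector, so one learned concept occupies several positions in the prompt. A minimal toy illustration of the expansion idea (illustrative ids and function name, not the actual `MultiCLIPTokenizer` logic):

```python
# Toy illustration: replace one placeholder id with N reserved ids so the
# text encoder receives N trainable vectors for a single concept.
def expand_placeholder(prompt_ids, placeholder_id, multi_ids):
    out = []
    for i in prompt_ids:
        out.extend(multi_ids if i == placeholder_id else [i])
    return out

prompt_ids = [49406, 320, 1125, 539, 49408, 49407]  # "<bos> a photo of <x> <eos>" (toy ids)
print(expand_placeholder(prompt_ids, 49408, [49408, 49409, 49410]))
# [49406, 320, 1125, 539, 49408, 49409, 49410, 49407]
```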
```diff
diff --git a/train_ti.py b/train_ti.py
index 088c1a6..69d15ea 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -16,17 +16,18 @@ from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler,
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_text_embeddings, load_text_embedding, load_config
+from common import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.ti import patch_trainable_embeddings
 from training.util import AverageMeter, CheckpointerBase, save_args
+from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
+from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
 
@@ -81,6 +82,12 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
+        "--num_vectors",
+        type=int,
+        nargs='*',
+        help="Number of vectors per embedding."
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=1,
@@ -360,8 +367,17 @@ def parse_args():
     if len(args.placeholder_token) == 0:
         args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
 
+    if args.num_vectors is None:
+        args.num_vectors = 1
+
+    if isinstance(args.num_vectors, int):
+        args.num_vectors = [args.num_vectors] * len(args.initializer_token)
+
     if len(args.placeholder_token) != len(args.initializer_token):
-        raise ValueError("You must specify --placeholder_token")
+        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
+
+    if len(args.placeholder_token) != len(args.num_vectors):
+        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
```
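The net effect of the new `--num_vectors` handling in `parse_args` can be summarized by the following standalone sketch; the helper name is illustrative, but the logic mirrors the added lines:

```python
def normalize_num_vectors(placeholder_token, initializer_token, num_vectors):
    # Mirrors the added parse_args() logic: default to one vector per token,
    # broadcast a single int across all placeholders, then validate lengths.
    if num_vectors is None:
        num_vectors = 1
    if isinstance(num_vectors, int):
        num_vectors = [num_vectors] * len(initializer_token)
    if len(placeholder_token) != len(initializer_token):
        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
    if len(placeholder_token) != len(num_vectors):
        raise ValueError("--placeholder_token and --num_vectors must have the same number of items")
    return num_vectors

# Two placeholders, no --num_vectors given -> one vector each
print(normalize_num_vectors(["<a>", "<b>"], ["cat", "dog"], None))   # [1, 1]
# Per-token counts passed explicitly -> used as-is
print(normalize_num_vectors(["<a>", "<b>"], ["cat", "dog"], [4, 2])) # [4, 2]
```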
```diff
@@ -386,8 +402,7 @@ class Checkpointer(CheckpointerBase):
         tokenizer,
         text_encoder,
         scheduler,
-        placeholder_token,
-        placeholder_token_id,
+        new_tokens,
         output_dir: Path,
         sample_image_size,
         sample_batches,
@@ -397,8 +412,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            placeholder_token=placeholder_token,
-            placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
             seed=seed or torch.random.seed(),
             sample_batches=sample_batches,
@@ -412,6 +425,7 @@ class Checkpointer(CheckpointerBase):
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
+        self.new_tokens = new_tokens
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
@@ -422,13 +436,11 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
-            # Save a checkpoint
-            learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id]
-            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
-
-            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        for new_token in self.new_tokens:
+            text_encoder.text_model.embeddings.save_embed(
+                new_token.multi_ids,
+                f"{slugify(new_token.token)}_{step}_{postfix}.bin"
+            )
 
         del text_encoder
         del learned_embeds
```
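`Checkpointer.checkpoint` now delegates per-token saving to `text_encoder.text_model.embeddings.save_embed(ids, filename)` instead of slicing the embedding matrix and calling `torch.save` by hand. `save_embed` lives in `models/clip/embeddings.py` and is not part of this diff; a plausible sketch of what such a helper might do, kept close to what the removed lines did for a single vector, is shown below (this is an assumption, not the repo's implementation). Note also that the retained `del learned_embeds` context line at the end of the hunk above appears to reference a name the rewritten loop no longer defines.

```python
import torch

def save_embed(embedding_weight: torch.Tensor, ids, filename):
    # Assumed behaviour: gather the rows for all vectors of one token and
    # write them to a standalone .bin file, detached and moved to the CPU.
    vectors = embedding_weight[ids].detach().cpu()
    torch.save(vectors, filename)

# Hypothetical usage matching the new call site:
# save_embed(embeddings.temp_token_embedding.weight.data,
#            new_token.multi_ids, f"{slug}_{step}_{postfix}.bin")
```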
```diff
@@ -487,9 +499,9 @@ def main():
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
-        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
+        tokenizer = MultiCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
@@ -507,45 +519,33 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    embeddings = patch_managed_embeddings(text_encoder)
+
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")
-        added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+
+        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
         for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+    ]
 
-    # Resize the token embeddings as we are adding new special tokens to the tokenizer
-    text_encoder.resize_token_embeddings(len(tokenizer))
+    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
 
-    # Initialise the newly added placeholder token with the embeddings of the initializer token
-    token_embeds = text_encoder.get_input_embeddings().weight.data
+    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
+        embeddings.add_embed(new_token.placeholder_id)
+        embeddings.add_embed(new_token.multi_ids, init_ids)
 
-    if args.resume_from is not None:
-        resumepath = Path(args.resume_from).joinpath("checkpoints")
-
-        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
-
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+    print(f"Added {len(new_tokens)} new tokens.")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
-    patch_trainable_embeddings(text_encoder, placeholder_token_id)
-
     text_encoder.text_model.encoder.requires_grad_(False)
     text_encoder.text_model.final_layer_norm.requires_grad_(False)
     text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
```
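After this hunk, the setup sequence in `main` is: patch the text encoder with managed embeddings, register every placeholder via `tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)` (which returns objects carrying a `placeholder_id` and per-vector `multi_ids`), and seed those vectors from the encoded initializer tokens via `embeddings.add_embed`. A condensed toy sketch of the initialization step follows; the managed-embedding class itself is in `models/clip/embeddings.py`, not shown here, and the repeat-the-initializer behaviour is an assumption:

```python
import torch

# Toy stand-in for the text encoder's token embedding table.
token_embedding = torch.nn.Embedding(num_embeddings=16, embedding_dim=4)

# Ids returned for one new multi-vector token and the encoded initializer word.
multi_ids = [12, 13, 14]   # three vectors reserved for "<skull>" (toy ids)
init_ids = [7]             # "skull" encodes to a single existing id (toy)

with torch.no_grad():
    # Assumed add_embed() behaviour: seed every new vector with the
    # initializer embedding, repeated if the initializer is shorter.
    seed = token_embedding.weight[init_ids * len(multi_ids)]
    token_embedding.weight[multi_ids] = seed

print(token_embedding.weight[multi_ids].shape)  # torch.Size([3, 4])
```

Only these new rows stay trainable; as the next hunk shows, the optimizer is pointed at `temp_token_embedding.parameters()` while the rest of the text encoder remains frozen.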
```diff
@@ -575,7 +575,7 @@ def main():
 
     # Initialize the optimizer
     optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.trainable_embedding.parameters(),  # only optimize the embeddings
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),  # only optimize the embeddings
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -816,6 +816,7 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
+    config["num_vectors"] = " ".join([str(n) for n in config["num_vectors"]])
     if config["collection"] is not None:
         config["collection"] = " ".join(config["collection"])
     if config["exclude_collections"] is not None:
@@ -852,8 +853,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
+        new_tokens=new_tokens,
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
```