 infer.py                  | 22
 train_lora.py             | 13
 train_ti.py               | 13
 training/functional.py    | 58
 training/strategy/lora.py |  8
 training/util.py          | 22
 util/ti.py                | 24
 7 files changed, 119 insertions(+), 41 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -35,6 +35,7 @@ from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from util.files import load_config, load_embeddings_from_dir
+from util.ti import load_embeddings
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -229,7 +230,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(pipeline, embeddings_dir):
+def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
@@ -258,6 +259,9 @@ def load_lora(pipeline, path):
     text_encoder_lora_ds = {
         k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
     }
+    ti_lora_ds = {
+        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
+    }
 
     unet_config = LoraConfig(**lora_config["peft_config"])
     pipeline.unet = LoraModel(unet_config, pipeline.unet)
@@ -268,6 +272,18 @@ def load_lora(pipeline, path):
     pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
     set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
 
+    tokens = [k for k, _ in ti_lora_ds.items()]
+    token_embeddings = [v for _, v in ti_lora_ds.items()]
+
+    added_tokens, added_ids = load_embeddings(
+        tokenizer=pipeline.tokenizer,
+        embeddings=pipeline.text_encoder.text_model.embeddings,
+        tokens=tokens,
+        token_embeddings=token_embeddings,
+    )
+    pipeline.text_encoder.text_model.embeddings.persist()
+    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+
     return
 
 
@@ -435,7 +451,7 @@ class CmdParse(cmd.Cmd):
            return True
 
        if elements[0] == 'reload_embeddings':
-           load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
+           load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
            return
 
        try:
@@ -475,7 +491,7 @@ def main():
 
    pipeline = create_pipeline(args.model, dtype)
 
-   load_embeddings(pipeline, args.ti_embeddings_dir)
+   load_embeddings_dir(pipeline, args.ti_embeddings_dir)
    load_lora(pipeline, args.lora_embedding)
    # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
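
Note on the checkpoint layout the loader above relies on: the LoRA strategy (see training/strategy/lora.py further down) writes everything into one safetensors file and distinguishes the groups purely by key prefix, with textual-inversion embeddings stored under "ti_<token>". A minimal standalone sketch of that split, using dummy tensors and an illustrative key; the real dict is whatever load_lora() reads from the checkpoint file:

    import torch

    # Dummy stand-in for the loaded checkpoint state dict (illustrative keys only).
    lora_checkpoint_sd = {
        "text_encoder_some_layer.lora_A.weight": torch.zeros(4, 4),
        "ti_<my-token>": torch.zeros(2, 768),
    }

    # Same prefix-based split as load_lora(): keys carrying "ti_" hold the per-token
    # embedding matrices, everything else is LoRA weights.
    ti_lora_ds = {
        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
    }
    tokens = list(ti_lora_ds.keys())              # ["<my-token>"]
    token_embeddings = list(ti_lora_ds.values())  # one (num_vectors, dim) tensor per token
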
diff --git a/train_lora.py b/train_lora.py
index c197206..9cf17c7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
-from training.util import save_args
+from training.util import AverageMeter, save_args
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
@@ -1035,6 +1035,11 @@ def main():
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
 
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -1122,7 +1127,7 @@
             warmup_epochs=lr_warmup_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}"
 
         trainer(
             strategy=lora_strategy,
@@ -1142,6 +1147,10 @@
             sample_frequency=lora_sample_frequency,
             offset_noise_strength=args.offset_noise_strength,
             no_val=args.valid_set_size == 0,
+            avg_loss=avg_loss,
+            avg_acc=avg_acc,
+            avg_loss_val=avg_loss_val,
+            avg_acc_val=avg_acc_val,
         )
 
         training_iter += 1
diff --git a/train_ti.py b/train_ti.py
index d1e5467..fce4a5e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.util import save_args
+from training.util import AverageMeter, save_args
 
 logger = get_logger(__name__)
 
@@ -920,6 +920,11 @@ def main():
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
 
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -977,7 +982,7 @@
             mid_point=args.lr_mid_point,
         )
 
-        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -994,6 +999,10 @@
             sample_frequency=sample_frequency,
             placeholder_tokens=placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
+            avg_loss=avg_loss,
+            avg_acc=avg_acc,
+            avg_loss_val=avg_loss_val,
+            avg_acc_val=avg_acc_val,
         )
 
         training_iter += 1
diff --git a/training/functional.py b/training/functional.py
index 695a24f..3036ed9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -461,6 +461,10 @@ def train_loop(
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
     group_labels: list[str] = [],
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -472,14 +476,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    best_acc = 0.0
-    best_acc_val = 0.0
+    best_acc = avg_acc.avg
+    best_acc_val = avg_acc_val.avg
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -544,12 +542,12 @@
 
                 accelerator.backward(loss)
 
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
+                avg_loss.update(loss.item(), bsz)
+                avg_acc.update(acc.item(), bsz)
 
                 logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
+                    "train/loss": avg_loss.avg,
+                    "train/acc": avg_acc.avg,
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
                 }
@@ -603,47 +601,47 @@
                     loss = loss.detach_()
                     acc = acc.detach_()
 
-                    cur_loss_val.update(loss, bsz)
-                    cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss.item(), bsz)
+                    cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss, bsz)
-                    avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss.item(), bsz)
+                    avg_acc_val.update(acc.item(), bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                     logs = {
-                        "val/loss": avg_loss_val.avg.item(),
-                        "val/acc": avg_acc_val.avg.item(),
+                        "val/loss": avg_loss_val.avg,
+                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            logs["val/cur_loss"] = cur_loss_val.avg.item()
-            logs["val/cur_acc"] = cur_acc_val.avg.item()
+            logs["val/cur_loss"] = cur_loss_val.avg
+            logs["val/cur_acc"] = cur_acc_val.avg
 
             accelerator.log(logs, step=global_step)
 
             if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
+                if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc_val = avg_acc_val.avg.item()
+                    best_acc_val = avg_acc_val.avg
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg.item() > best_acc and milestone_checkpoints:
+                if avg_acc.avg > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
-                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc = avg_acc.avg.item()
+                    best_acc = avg_acc.avg
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -688,6 +686,10 @@ def train(
     offset_noise_strength: float = 0.15,
     disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
@@ -737,6 +739,10 @@
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
         group_labels=group_labels,
+        avg_loss=avg_loss,
+        avg_acc=avg_acc,
+        avg_loss_val=avg_loss_val,
+        avg_acc_val=avg_acc_val,
         callbacks=callbacks,
     )
 
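
A general Python note on the new keyword defaults such as avg_loss: AverageMeter = AverageMeter() (language behaviour, not something this diff changes): the default object is created once, when the function is defined, so any two calls that omit the argument share the same meter instance. The training scripts above always pass their own instances, which sidesteps this; a tiny standalone sketch of the pitfall:

    class Meter:
        def __init__(self):
            self.avg = 0.0

        def update(self, val):
            self.avg += val


    def train_loop(avg: Meter = Meter()):  # the default Meter is built exactly once
        avg.update(1.0)
        return avg.avg


    print(train_loop())  # 1.0
    print(train_loop())  # 2.0 -- state carried over through the shared default object
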
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1f0a117..3f4dbbc 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -138,6 +138,14 @@ def lora_strategy_callbacks(
         state_dict.update(text_encoder_state_dict)
         lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
+        if len(placeholder_tokens) != 0:
+            ti_state_dict = {
+                f"ti_{token}": text_encoder.text_model.embeddings.get_embed(ids)
+                for (token, ids)
+                in zip(placeholder_tokens, placeholder_token_ids)
+            }
+            state_dict.update(ti_state_dict)
+
         save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
diff --git a/training/util.py b/training/util.py
index 8bd8a83..61f1533 100644
--- a/training/util.py
+++ b/training/util.py
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
-    avg: Any
-
-    def __init__(self, name=None):
-        self.name = name
+    def __init__(self, inv_gamma=1.0, power=2 / 3):
+        self.inv_gamma = inv_gamma
+        self.power = power
         self.reset()
 
     def reset(self):
-        self.sum = self.count = self.avg = 0
+        self.step = 0
+        self.avg = 0
+
+    def get_decay(self):
+        if self.step <= 0:
+            return 1
+
+        return (self.step / self.inv_gamma) ** -self.power
 
     def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
+        for _ in range(n):
+            self.step += 1
+            self.avg += (val - self.avg) * self.get_decay()
 
 
 class EMAModel(EMAModel_):
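
The AverageMeter rewrite changes its meaning: instead of the exact arithmetic mean (sum / count), it now keeps a moving average whose update weight decays polynomially with the number of observed samples, which is also why train_loop above reads avg.avg directly instead of calling .item(). A standalone sketch of how the decay behaves with the default inv_gamma=1.0, power=2/3 (illustrative input values):

    # Mirrors the update rule above: avg += (val - avg) * decay(step).
    def decay(step, inv_gamma=1.0, power=2 / 3):
        return 1 if step <= 0 else (step / inv_gamma) ** -power

    avg = 0.0
    for step, val in enumerate([1.0, 1.0, 1.0, 0.0], start=1):
        avg += (val - avg) * decay(step)
        print(step, round(decay(step), 3), round(avg, 3))
    # 1 1.0   1.0    (the first sample is taken as-is)
    # 2 0.63  1.0
    # 3 0.481 1.0
    # 4 0.397 0.603  (late samples still move the average, just less than early ones)
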
diff --git a/util/ti.py b/util/ti.py
new file mode 100644
index 0000000..4cc732e
--- /dev/null
+++ b/util/ti.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+
+import torch
+
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+
+def load_embeddings(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    tokens: list[str],
+    token_embeddings: torch.FloatTensor,
+):
+    num_vectors = [embedding.shape[0] for embedding in token_embeddings]
+
+    token_ids = tokenizer.add_multi_tokens(tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, embeds) in zip(token_ids, token_embeddings):
+        embeddings.add_embed(new_id, embeds)
+
+    return tokens, token_ids
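
One note on load_embeddings: although token_embeddings is annotated as torch.FloatTensor, it is consumed as a sequence of per-token tensors (num_vectors reads embedding.shape[0] for each element), so each placeholder token brings one (num_vectors, embedding_dim) matrix. Illustrative dummy data; the 768 dimension is just an example and depends on the text encoder:

    import torch

    tokens = ["<art-style>", "<subject>"]
    token_embeddings = [
        torch.randn(1, 768),  # token backed by a single embedding vector
        torch.randn(4, 768),  # token expanded into 4 vectors
    ]

    num_vectors = [e.shape[0] for e in token_embeddings]  # [1, 4], as computed in load_embeddings
    # load_embeddings(tokenizer, embeddings, tokens, token_embeddings) would then register both
    # tokens via tokenizer.add_multi_tokens(tokens, num_vectors) and copy the rows with add_embed().
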
