summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-22 16:37:47 +0100
committerVolpeon <git@volpeon.ink>2022-12-22 16:37:47 +0100
commitfd691d762820863c5236a189a752ba4f985a961b (patch)
tree1f8db6c6629cdf7df552d7f24e0e7dd16c593b7f
parentSome LoRA fixes (still broken) (diff)
downloadtextual-inversion-diff-fd691d762820863c5236a189a752ba4f985a961b.tar.gz
textual-inversion-diff-fd691d762820863c5236a189a752ba4f985a961b.tar.bz2
textual-inversion-diff-fd691d762820863c5236a189a752ba4f985a961b.zip
Improved Textual Inversion: Completely exclude untrained embeddings from training
-rw-r--r--train_ti.py24
-rw-r--r--training/ti.py70
2 files changed, 76 insertions, 18 deletions
diff --git a/train_ti.py b/train_ti.py
index 198cf37..bb51dc2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -24,7 +24,8 @@ from common import load_text_embeddings, load_text_embedding
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule, CSVDataItem 25from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
27from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args 27from training.ti import patch_trainable_embeddings
28from training.util import AverageMeter, CheckpointerBase, save_args
28from models.clip.prompt import PromptProcessor 29from models.clip.prompt import PromptProcessor
29 30
30logger = get_logger(__name__) 31logger = get_logger(__name__)
@@ -512,24 +513,14 @@ def main():
512 for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): 513 for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
513 load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin")) 514 load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
514 515
515 original_token_embeds = token_embeds.clone().to(accelerator.device)
516
517 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) 516 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
518 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): 517 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
519 token_embeds[token_id] = embeddings 518 token_embeds[token_id] = embeddings
520 519
521 index_fixed_tokens = torch.arange(len(tokenizer)) 520 vae.requires_grad_(False)
522 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] 521 unet.requires_grad_(False)
523 522
524 # Freeze vae and unet 523 text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id)
525 freeze_params(vae.parameters())
526 freeze_params(unet.parameters())
527 # Freeze all parameters except for the token embeddings in text encoder
528 freeze_params(itertools.chain(
529 text_encoder.text_model.encoder.parameters(),
530 text_encoder.text_model.final_layer_norm.parameters(),
531 text_encoder.text_model.embeddings.position_embedding.parameters(),
532 ))
533 524
534 prompt_processor = PromptProcessor(tokenizer, text_encoder) 525 prompt_processor = PromptProcessor(tokenizer, text_encoder)
535 526
@@ -843,10 +834,7 @@ def main():
843 lr_scheduler.step() 834 lr_scheduler.step()
844 optimizer.zero_grad(set_to_none=True) 835 optimizer.zero_grad(set_to_none=True)
845 836
846 # Let's make sure we don't update any embedding weights besides the newly added token 837 text_embeddings.save()
847 with torch.no_grad():
848 text_encoder.get_input_embeddings(
849 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
850 838
851 avg_loss.update(loss.detach_(), bsz) 839 avg_loss.update(loss.detach_(), bsz)
852 avg_acc.update(acc.detach_(), bsz) 840 avg_acc.update(acc.detach_(), bsz)
diff --git a/training/ti.py b/training/ti.py
new file mode 100644
index 0000000..a5fd8e4
--- /dev/null
+++ b/training/ti.py
@@ -0,0 +1,70 @@
1from typing import Optional
2
3import torch
4import torch.nn as nn
5
6from transformers.models.clip import CLIPTextModel, CLIPTextConfig
7from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
8
9
10def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
11 text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
12 text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
13 text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
14 text_encoder.text_model.embeddings = text_embeddings
15 return text_embeddings
16
17
18class TrainableEmbeddings(CLIPTextEmbeddings):
19 def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
20 super().__init__(config)
21
22 self.token_embedding.requires_grad_(False)
23 self.position_embedding.requires_grad_(False)
24
25 self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
26
27 indices = torch.arange(self.token_embedding.num_embeddings)
28 self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
29
30 self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
31
32 def forward(
33 self,
34 input_ids: Optional[torch.LongTensor] = None,
35 position_ids: Optional[torch.LongTensor] = None,
36 inputs_embeds: Optional[torch.FloatTensor] = None,
37 ) -> torch.Tensor:
38 seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
39
40 if position_ids is None:
41 position_ids = self.position_ids[:, :seq_length]
42
43 if inputs_embeds is None:
44 mask = torch.isin(
45 input_ids,
46 self.train_indices.to(input_ids.device)
47 ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim)
48
49 trainable_input_ids = torch.tensor([
50 [
51 self.id_mapping[id] if id in self.id_mapping else 0
52 for id in batch
53 ]
54 for batch in input_ids
55 ], device=input_ids.device)
56
57 inputs_embeds = torch.where(
58 mask,
59 self.trainable_embedding(trainable_input_ids),
60 self.token_embedding(input_ids)
61 )
62
63 position_embeddings = self.position_embedding(position_ids)
64 embeddings = inputs_embeds + position_embeddings
65
66 return embeddings
67
68 @torch.no_grad()
69 def save(self):
70 self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data