-rw-r--r--   train_ti.py    | 24 +-
-rw-r--r--   training/ti.py | 70 +
2 files changed, 76 insertions(+), 18 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 198cf37..bb51dc2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -24,7 +24,8 @@ from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.ti import patch_trainable_embeddings
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -512,24 +513,14 @@ def main():
         for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
             load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{token}_{args.global_step}_end.bin"))
 
-    original_token_embeds = token_embeds.clone().to(accelerator.device)
-
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
     for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
         token_embeds[token_id] = embeddings
 
-    index_fixed_tokens = torch.arange(len(tokenizer))
-    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
 
-    # Freeze vae and unet
-    freeze_params(vae.parameters())
-    freeze_params(unet.parameters())
-    # Freeze all parameters except for the token embeddings in text encoder
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_embeddings = patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -843,10 +834,7 @@ def main():
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
-                # Let's make sure we don't update any embedding weights besides the newly added token
-                with torch.no_grad():
-                    text_encoder.get_input_embeddings(
-                    ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+                text_embeddings.save()
 
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
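
Note: the removed per-step restore and the new text_embeddings.save() call enforce the same invariant from opposite directions. Previously, every embedding row except the placeholder tokens was reset to its original value after each optimizer step; now only the placeholder rows are ever written back into the shared weight matrix. A rough equivalence sketch, reusing the variable names from the old and new code above:

    # Old (removed): after every step, reset all non-placeholder rows to their original values.
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]

    # New: copy only the placeholder rows from the small trainable table back into the shared matrix.
    text_embeddings.save()  # token_embedding.weight[train_indices] = trainable_embedding.weight
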
diff --git a/training/ti.py b/training/ti.py
new file mode 100644
index 0000000..a5fd8e4
--- /dev/null
+++ b/training/ti.py
@@ -0,0 +1,70 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from transformers.models.clip import CLIPTextModel, CLIPTextConfig
+from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
+
+
+def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
+    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
+    text_embeddings.token_embedding.weight = text_encoder.text_model.embeddings.token_embedding.weight
+    text_embeddings.position_embedding.weight = text_encoder.text_model.embeddings.position_embedding.weight
+    text_encoder.text_model.embeddings = text_embeddings
+    return text_embeddings
+
+
+class TrainableEmbeddings(CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
+        super().__init__(config)
+
+        self.token_embedding.requires_grad_(False)
+        self.position_embedding.requires_grad_(False)
+
+        self.id_mapping = {new_ids[i]: i for i in range(len(new_ids))}
+
+        indices = torch.arange(self.token_embedding.num_embeddings)
+        self.train_indices = indices[torch.isin(indices, torch.tensor(new_ids))]
+
+        self.trainable_embedding = nn.Embedding.from_pretrained(self.token_embedding.weight[self.train_indices])
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            mask = torch.isin(
+                input_ids,
+                self.train_indices.to(input_ids.device)
+            ).unsqueeze(-1).expand(-1, -1, self.token_embedding.embedding_dim)
+
+            trainable_input_ids = torch.tensor([
+                [
+                    self.id_mapping[id] if id in self.id_mapping else 0
+                    for id in batch
+                ]
+                for batch in input_ids
+            ], device=input_ids.device)
+
+            inputs_embeds = torch.where(
+                mask,
+                self.trainable_embedding(trainable_input_ids),
+                self.token_embedding(input_ids)
+            )
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+    @torch.no_grad()
+    def save(self):
+        self.token_embedding.weight.data[self.train_indices] = self.trainable_embedding.weight.data
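
Taken together, patch_trainable_embeddings swaps the text encoder's CLIPTextEmbeddings module for a subclass that routes the placeholder token ids through a small separate nn.Embedding while all other tokens keep using the original token_embedding table, and save() writes the learned rows back into that shared matrix. A minimal usage sketch, independent of train_ti.py; the checkpoint name and the <new-concept> token are illustrative placeholders, not part of this commit:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    from training.ti import patch_trainable_embeddings

    pretrained = "openai/clip-vit-large-patch14"  # any CLIP text encoder works the same way
    tokenizer = CLIPTokenizer.from_pretrained(pretrained)
    text_encoder = CLIPTextModel.from_pretrained(pretrained)

    # Register a placeholder token and grow the embedding matrix,
    # as train_ti.py does before this point.
    tokenizer.add_tokens("<new-concept>")
    new_ids = tokenizer.convert_tokens_to_ids(["<new-concept>"])
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Replace the embeddings module; prompts containing <new-concept> are now
    # looked up in trainable_embedding, all other tokens in the original table.
    text_embeddings = patch_trainable_embeddings(text_encoder, new_ids)

    input_ids = tokenizer("a photo of <new-concept>", return_tensors="pt").input_ids
    with torch.no_grad():
        _ = text_encoder(input_ids)

    # After optimizing trainable_embedding, fold the learned rows back
    # into token_embedding.weight so the encoder can be saved as usual.
    text_embeddings.save()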