-rw-r--r--  models/clip/embeddings.py  15
-rw-r--r--  models/sparse.py           14
-rw-r--r--  train_ti.py                24
-rw-r--r--  training/functional.py      4
-rw-r--r--  training/strategy/ti.py    22
5 files changed, 57 insertions, 22 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index a356434..63a141f 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: float = 1.0):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -49,7 +49,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype,
         )
-        self.alpha = alpha
 
     def resize(self, size: int):
         self.token_override_embedding.resize(size)
@@ -87,7 +86,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.token_embedding.weight.data[token_ids] = initializer
-        self.token_override_embedding.set(token_ids)
+        self.token_override_embedding.set(token_ids, initializer)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
@@ -101,8 +100,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
-            self.token_embedding.weight.data[input_ids] += self.alpha * embs
+            self.token_embedding.weight.data[input_ids] = embs
             self.token_override_embedding.unset(input_ids)
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
@@ -111,7 +110,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         embs = self.token_embedding(input_ids)
         embs_override, mask = self.token_override_embedding(input_ids)
         if embs_override is not None:
-            embs[mask] += self.alpha * embs_override
+            embs[mask] = embs_override
 
         return embs
 
@@ -135,7 +134,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel, alpha: float = 1.0) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, alpha)
+def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
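Note: with the alpha blending factor removed, token_override_embedding now stores full replacement vectors rather than alpha-scaled residuals, so lookup and persistence reduce to plain assignment. A minimal sketch of the new lookup semantics (get_embed_sketch is a hypothetical free-standing helper; the surrounding class is elided, names otherwise follow the diff):

import torch

def get_embed_sketch(token_embedding, token_override_embedding, input_ids: torch.LongTensor) -> torch.Tensor:
    # Base lookup from the regular CLIP embedding table.
    embs = token_embedding(input_ids)
    # Where an override exists, it replaces the base vector outright;
    # before this commit it was added as a scaled delta:
    #   embs[mask] += alpha * embs_override
    embs_override, mask = token_override_embedding(input_ids)
    if embs_override is not None:
        embs[mask] = embs_override
    return embs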
diff --git a/models/sparse.py b/models/sparse.py
index 0b15454..8910316 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -13,10 +13,7 @@ class PseudoSparseEmbedding(nn.Module):
         self.params = nn.ParameterList()
         self.mapping = torch.zeros(0, device=device, dtype=torch.long)
 
-    def forward(self, input_ids: Optional[torch.LongTensor] = None):
-        if input_ids is None:
-            input_ids = torch.arange(self.mapping.shape[0])
-
+    def forward(self, input_ids: torch.LongTensor):
         ids = self.mapping[input_ids.to(self.mapping.device)]
         mask = ~(ids == -1)
 
@@ -43,6 +40,12 @@ class PseudoSparseEmbedding(nn.Module):
             else:
                 return [self.set(id) for id in input_ids]
 
+        if tensor is None:
+            tensor = torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
+
+        if tensor.shape[-1] != self.embedding_dim:
+            raise ValueError(f"Expected tensor of shape [..., {self.embedding_dim}], but got [..., {tensor.shape[-1]}]")
+
         id = self.mapping[input_ids]
 
         if id == -1:
@@ -50,8 +53,7 @@ class PseudoSparseEmbedding(nn.Module):
             self.mapping[input_ids] = id
             self.params.append(torch.zeros(self.embedding_dim, device=self.mapping.device, dtype=self.dtype))
 
-        self.params[id] = tensor if tensor is not None else torch.zeros(
-            self.embedding_dim, device=self.mapping.device, dtype=self.dtype)
+        self.params[id] = tensor
 
     def unset(self, input_ids: torch.LongTensor):
         self.mapping[input_ids] = -1
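Note: set() now normalizes its tensor argument once up front (zeros as the default, plus a width check) instead of deciding at assignment time, so a malformed override fails fast. A hedged usage sketch; it assumes the constructor takes the embedding width and that resize() grows the mapping with -1 (unset) slots, which the diff implies but does not show:

import torch
from models.sparse import PseudoSparseEmbedding

emb = PseudoSparseEmbedding(768)            # assumed constructor signature
emb.resize(16)                              # assumed: make room for token ids 0..15
emb.set(torch.tensor(5), torch.randn(768))  # store an override vector for id 5
embs, mask = emb(torch.tensor([3, 5]))      # mask marks which ids carry overrides
emb.set(torch.tensor(5), torch.randn(10))   # wrong width: now raises ValueError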
diff --git a/train_ti.py b/train_ti.py
index a9a2333..4366c9e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -353,7 +353,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=1e-2,
+        default=0,
        help="Weight decay to use."
     )
     parser.add_argument(
@@ -451,10 +451,21 @@
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--emb_alpha",
-        default=1.0,
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
+        "--emb_decay_target",
+        default=0.4,
+        type=float,
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--emb_decay",
+        default=1e+2,
         type=float,
-        help="Embedding alpha."
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -600,7 +611,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path, args.emb_alpha)
+        args.pretrained_model_name_or_path)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -744,6 +755,9 @@
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
         use_ema=args.use_ema,
         ema_inv_gamma=args.ema_inv_gamma,
         ema_power=args.ema_power,
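Note: --emb_alpha is gone; the three embedding-decay flags replace it, and --adam_weight_decay now defaults to 0 since decay is applied to the embedding vectors directly rather than through the optimizer. A hypothetical invocation with the new flags (all other required arguments elided, model path is a placeholder):

python train_ti.py \
    --pretrained_model_name_or_path <model> \
    --use_emb_decay \
    --emb_decay_target 0.4 \
    --emb_decay 1e+2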
diff --git a/training/functional.py b/training/functional.py
index 1d8e2ee..96ecbc1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -82,7 +82,7 @@ def get_models(pretrained_model_name_or_path: str, emb_alpha: float = 1.0):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder, emb_alpha)
+    embeddings = patch_managed_embeddings(text_encoder)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
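Note: get_models() now takes only the model path, and patch_managed_embeddings() is called without alpha. An illustrative call (the model path is an example, not from this repo):

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
    get_models('runwayml/stable-diffusion-v1-5')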
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 95128da..9df160a 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -31,6 +31,9 @@ def textual_inversion_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay: float = 1e-2,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
@@ -102,10 +105,26 @@
         yield
 
     @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
+    def on_before_optimize(lr: float, epoch: int):
+        if use_emb_decay:
+            return torch.stack([
+                p
+                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                if p.grad is not None
+            ])
+
+    @torch.no_grad()
+    def on_after_optimize(w, lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
+        if use_emb_decay:
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -149,6 +168,7 @@
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
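Note: this is a norm decay rather than standard L2 weight decay. on_before_optimize stacks the override vectors that received gradients (the training loop apparently passes that stack back to on_after_optimize as w), and on_after_optimize then pulls each vector's L2 norm toward emb_decay_target with step lambda = emb_decay * lr. A standalone sketch of that update rule, same math as the diff but outside the callback plumbing (emb_decay_step is a hypothetical helper; defaults taken from train_ti.py):

import torch

@torch.no_grad()
def emb_decay_step(w: torch.Tensor, lr: float, emb_decay: float = 1e+2, emb_decay_target: float = 0.4) -> None:
    # w: [n, dim] stack of embedding vectors, updated in place.
    # Per row: w <- w + lambda * (target - ||w||) * w / ||w||,
    # so vectors already at the target norm are fixed points.
    lambda_ = emb_decay * lr
    if lambda_ != 0:
        norm = w.norm(dim=-1, keepdim=True)
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))

For example, with lr = 1e-4 and the train_ti.py default emb_decay = 1e+2, lambda = 0.01: a vector with norm 1.0 is rescaled toward 1.0 + 0.01 * (0.4 - 1.0) = 0.994 in one step.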