diff options
-rw-r--r-- | models/clip/embeddings.py | 19 | ||||
-rw-r--r-- | train_ti.py | 30 | ||||
-rw-r--r-- | training/strategy/ti.py | 76 |
3 files changed, 38 insertions, 87 deletions
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 88e0cc0..c9c788c 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -66,12 +66,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
66 | self.initializer_factor = config.initializer_factor | 66 | self.initializer_factor = config.initializer_factor |
67 | 67 | ||
68 | self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) | 68 | self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank) |
69 | self.temp_token_embedding = nn.Embedding( | ||
70 | self.token_embedding.num_embeddings, | ||
71 | self.token_embedding.embedding_dim, | ||
72 | device=self.token_embedding.weight.device, | ||
73 | dtype=self.token_embedding.weight.dtype | ||
74 | ) | ||
75 | self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach() | ||
69 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 76 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
70 | 77 | ||
71 | def reset_overlay(self): | 78 | def reset_overlay(self): |
72 | self.overlay.reset() | 79 | self.overlay.reset() |
73 | 80 | ||
74 | def resize(self, size: int): | 81 | def resize(self, size: int): |
82 | self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor) | ||
75 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) | 83 | self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor) |
76 | 84 | ||
77 | def add_embed( | 85 | def add_embed( |
@@ -106,6 +114,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
106 | token_ids = torch.tensor(token_ids, dtype=torch.long) | 114 | token_ids = torch.tensor(token_ids, dtype=torch.long) |
107 | 115 | ||
108 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 116 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
117 | self.temp_token_embedding.weight.data[token_ids] = initializer | ||
109 | self.token_embedding.weight.data[token_ids] = initializer | 118 | self.token_embedding.weight.data[token_ids] = initializer |
110 | 119 | ||
111 | def load_embed(self, input_ids: list[int], filename: Path): | 120 | def load_embed(self, input_ids: list[int], filename: Path): |
@@ -116,9 +125,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
116 | save_file({"embed": self.get_embed(input_ids)}, filename) | 125 | save_file({"embed": self.get_embed(input_ids)}, filename) |
117 | 126 | ||
118 | def persist(self): | 127 | def persist(self): |
119 | self.token_embedding.weight.data[self.temp_token_ids] += self.overlay( | 128 | embeds = self.temp_token_embedding.weight.data[self.temp_token_ids] |
120 | self.token_embedding.weight.data[self.temp_token_ids] | 129 | self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds) |
121 | ) | ||
122 | self.overlay.reset() | 130 | self.overlay.reset() |
123 | self.temp_token_ids = torch.tensor([], dtype=torch.long) | 131 | self.temp_token_ids = torch.tensor([], dtype=torch.long) |
124 | 132 | ||
@@ -127,8 +135,11 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
127 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 135 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) |
128 | 136 | ||
129 | embeds = self.token_embedding(input_ids) | 137 | embeds = self.token_embedding(input_ids) |
138 | |||
130 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) | 139 | mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device)) |
131 | embeds[mask] += self.overlay(embeds[mask]) | 140 | |
141 | temp_embeds = self.temp_token_embedding(input_ids[mask]) | ||
142 | embeds[mask] = temp_embeds + self.overlay(temp_embeds) | ||
132 | 143 | ||
133 | return embeds | 144 | return embeds |
134 | 145 | ||
diff --git a/train_ti.py b/train_ti.py index 0ce0056..26ac384 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -1,6 +1,7 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import datetime | 2 | import datetime |
3 | import logging | 3 | import logging |
4 | import itertools | ||
4 | from functools import partial | 5 | from functools import partial |
5 | from pathlib import Path | 6 | from pathlib import Path |
6 | import math | 7 | import math |
@@ -307,26 +308,6 @@ def parse_args(): | |||
307 | help="Minimum learning rate in the lr scheduler." | 308 | help="Minimum learning rate in the lr scheduler." |
308 | ) | 309 | ) |
309 | parser.add_argument( | 310 | parser.add_argument( |
310 | "--use_ema", | ||
311 | action="store_true", | ||
312 | help="Whether to use EMA model." | ||
313 | ) | ||
314 | parser.add_argument( | ||
315 | "--ema_inv_gamma", | ||
316 | type=float, | ||
317 | default=1.0 | ||
318 | ) | ||
319 | parser.add_argument( | ||
320 | "--ema_power", | ||
321 | type=float, | ||
322 | default=4/5 | ||
323 | ) | ||
324 | parser.add_argument( | ||
325 | "--ema_max_decay", | ||
326 | type=float, | ||
327 | default=0.9999 | ||
328 | ) | ||
329 | parser.add_argument( | ||
330 | "--optimizer", | 311 | "--optimizer", |
331 | type=str, | 312 | type=str, |
332 | default="dadan", | 313 | default="dadan", |
@@ -715,10 +696,6 @@ def main(): | |||
715 | sample_scheduler=sample_scheduler, | 696 | sample_scheduler=sample_scheduler, |
716 | checkpoint_output_dir=checkpoint_output_dir, | 697 | checkpoint_output_dir=checkpoint_output_dir, |
717 | gradient_checkpointing=args.gradient_checkpointing, | 698 | gradient_checkpointing=args.gradient_checkpointing, |
718 | use_ema=args.use_ema, | ||
719 | ema_inv_gamma=args.ema_inv_gamma, | ||
720 | ema_power=args.ema_power, | ||
721 | ema_max_decay=args.ema_max_decay, | ||
722 | sample_batch_size=args.sample_batch_size, | 699 | sample_batch_size=args.sample_batch_size, |
723 | sample_num_batches=args.sample_batches, | 700 | sample_num_batches=args.sample_batches, |
724 | sample_num_steps=args.sample_steps, | 701 | sample_num_steps=args.sample_steps, |
@@ -780,7 +757,10 @@ def main(): | |||
780 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 757 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
781 | 758 | ||
782 | optimizer = create_optimizer( | 759 | optimizer = create_optimizer( |
783 | text_encoder.text_model.embeddings.overlay.parameters(), | 760 | itertools.chain( |
761 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | ||
762 | text_encoder.text_model.embeddings.overlay.parameters(), | ||
763 | ), | ||
784 | lr=args.learning_rate, | 764 | lr=args.learning_rate, |
785 | ) | 765 | ) |
786 | 766 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 19b8d25..33f5fb9 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -1,6 +1,6 @@ | |||
1 | from typing import Optional | 1 | from typing import Optional |
2 | from functools import partial | 2 | from functools import partial |
3 | from contextlib import contextmanager, nullcontext | 3 | from contextlib import contextmanager |
4 | from pathlib import Path | 4 | from pathlib import Path |
5 | 5 | ||
6 | import torch | 6 | import torch |
@@ -13,7 +13,6 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
13 | from slugify import slugify | 13 | from slugify import slugify |
14 | 14 | ||
15 | from models.clip.tokenizer import MultiCLIPTokenizer | 15 | from models.clip.tokenizer import MultiCLIPTokenizer |
16 | from training.util import EMAModel | ||
17 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 16 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
18 | 17 | ||
19 | 18 | ||
@@ -32,10 +31,6 @@ def textual_inversion_strategy_callbacks( | |||
32 | placeholder_tokens: list[str], | 31 | placeholder_tokens: list[str], |
33 | placeholder_token_ids: list[list[int]], | 32 | placeholder_token_ids: list[list[int]], |
34 | gradient_checkpointing: bool = False, | 33 | gradient_checkpointing: bool = False, |
35 | use_ema: bool = False, | ||
36 | ema_inv_gamma: float = 1.0, | ||
37 | ema_power: int = 1, | ||
38 | ema_max_decay: float = 0.9999, | ||
39 | sample_batch_size: int = 1, | 34 | sample_batch_size: int = 1, |
40 | sample_num_batches: int = 1, | 35 | sample_num_batches: int = 1, |
41 | sample_num_steps: int = 20, | 36 | sample_num_steps: int = 20, |
@@ -68,25 +63,6 @@ def textual_inversion_strategy_callbacks( | |||
68 | image_size=sample_image_size, | 63 | image_size=sample_image_size, |
69 | ) | 64 | ) |
70 | 65 | ||
71 | if use_ema: | ||
72 | ema_embeddings = EMAModel( | ||
73 | text_encoder.text_model.embeddings.overlay.parameters(), | ||
74 | inv_gamma=ema_inv_gamma, | ||
75 | power=ema_power, | ||
76 | max_value=ema_max_decay, | ||
77 | ) | ||
78 | ema_embeddings.to(accelerator.device) | ||
79 | else: | ||
80 | ema_embeddings = None | ||
81 | |||
82 | def ema_context(): | ||
83 | if ema_embeddings is not None: | ||
84 | return ema_embeddings.apply_temporary( | ||
85 | text_encoder.text_model.embeddings.overlay.parameters() | ||
86 | ) | ||
87 | else: | ||
88 | return nullcontext() | ||
89 | |||
90 | def on_accum_model(): | 66 | def on_accum_model(): |
91 | return text_encoder.text_model.embeddings.overlay | 67 | return text_encoder.text_model.embeddings.overlay |
92 | 68 | ||
@@ -98,50 +74,36 @@ def textual_inversion_strategy_callbacks( | |||
98 | @contextmanager | 74 | @contextmanager |
99 | def on_eval(): | 75 | def on_eval(): |
100 | tokenizer.eval() | 76 | tokenizer.eval() |
101 | 77 | yield | |
102 | with ema_context(): | ||
103 | yield | ||
104 | |||
105 | @torch.no_grad() | ||
106 | def on_after_optimize(zero_ids, lr: float): | ||
107 | if ema_embeddings is not None: | ||
108 | ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters()) | ||
109 | |||
110 | def on_log(): | ||
111 | if ema_embeddings is not None: | ||
112 | return {"ema_decay": ema_embeddings.decay} | ||
113 | return {} | ||
114 | 78 | ||
115 | @torch.no_grad() | 79 | @torch.no_grad() |
116 | def on_checkpoint(step, postfix): | 80 | def on_checkpoint(step, postfix): |
117 | print(f"Saving checkpoint for step {step}...") | 81 | print(f"Saving checkpoint for step {step}...") |
118 | 82 | ||
119 | with ema_context(): | 83 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): |
120 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 84 | text_encoder.text_model.embeddings.save_embed( |
121 | text_encoder.text_model.embeddings.save_embed( | 85 | ids, |
122 | ids, | 86 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" |
123 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 87 | ) |
124 | ) | ||
125 | 88 | ||
126 | @torch.no_grad() | 89 | @torch.no_grad() |
127 | def on_sample(step): | 90 | def on_sample(step): |
128 | with ema_context(): | 91 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
129 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 92 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
130 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | ||
131 | 93 | ||
132 | orig_unet_dtype = unet_.dtype | 94 | orig_unet_dtype = unet_.dtype |
133 | orig_text_encoder_dtype = text_encoder_.dtype | 95 | orig_text_encoder_dtype = text_encoder_.dtype |
134 | 96 | ||
135 | unet_.to(dtype=weight_dtype) | 97 | unet_.to(dtype=weight_dtype) |
136 | text_encoder_.to(dtype=weight_dtype) | 98 | text_encoder_.to(dtype=weight_dtype) |
137 | 99 | ||
138 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 100 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) |
139 | 101 | ||
140 | unet_.to(dtype=orig_unet_dtype) | 102 | unet_.to(dtype=orig_unet_dtype) |
141 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 103 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
142 | 104 | ||
143 | del unet_ | 105 | del unet_ |
144 | del text_encoder_ | 106 | del text_encoder_ |
145 | 107 | ||
146 | if torch.cuda.is_available(): | 108 | if torch.cuda.is_available(): |
147 | torch.cuda.empty_cache() | 109 | torch.cuda.empty_cache() |
@@ -150,8 +112,6 @@ def textual_inversion_strategy_callbacks( | |||
150 | on_accum_model=on_accum_model, | 112 | on_accum_model=on_accum_model, |
151 | on_train=on_train, | 113 | on_train=on_train, |
152 | on_eval=on_eval, | 114 | on_eval=on_eval, |
153 | on_after_optimize=on_after_optimize, | ||
154 | on_log=on_log, | ||
155 | on_checkpoint=on_checkpoint, | 115 | on_checkpoint=on_checkpoint, |
156 | on_sample=on_sample, | 116 | on_sample=on_sample, |
157 | ) | 117 | ) |