 models/clip/embeddings.py | 19 +++++++++----
 train_ti.py               | 30 +-----------------------
 training/strategy/ti.py   | 76 ++++++++--------------------------
 3 files changed, 38 insertions(+), 87 deletions(-)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 88e0cc0..c9c788c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -66,12 +66,20 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.initializer_factor = config.initializer_factor
 
         self.overlay = OverlayLinear(self.token_embedding.embedding_dim, self.token_embedding.embedding_dim, rank)
+        self.temp_token_embedding = nn.Embedding(
+            self.token_embedding.num_embeddings,
+            self.token_embedding.embedding_dim,
+            device=self.token_embedding.weight.device,
+            dtype=self.token_embedding.weight.dtype
+        )
+        self.temp_token_embedding.weight.data = self.token_embedding.weight.data.clone().detach()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
     def reset_overlay(self):
         self.overlay.reset()
 
     def resize(self, size: int):
+        self.temp_token_embedding = resize_embedding(self.temp_token_embedding, size, self.initializer_factor)
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)
 
     def add_embed(
@@ -106,6 +114,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
+        self.temp_token_embedding.weight.data[token_ids] = initializer
         self.token_embedding.weight.data[token_ids] = initializer
 
     def load_embed(self, input_ids: list[int], filename: Path):
@@ -116,9 +125,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.weight.data[self.temp_token_ids] += self.overlay(
-            self.token_embedding.weight.data[self.temp_token_ids]
-        )
+        embeds = self.temp_token_embedding.weight.data[self.temp_token_ids]
+        self.token_embedding.weight.data[self.temp_token_ids] = embeds + self.overlay(embeds)
         self.overlay.reset()
         self.temp_token_ids = torch.tensor([], dtype=torch.long)
 
@@ -127,8 +135,11 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long)
 
         embeds = self.token_embedding(input_ids)
+
         mask = torch.isin(input_ids, self.temp_token_ids.to(input_ids.device))
-        embeds[mask] += self.overlay(embeds[mask])
+
+        temp_embeds = self.temp_token_embedding(input_ids[mask])
+        embeds[mask] = temp_embeds + self.overlay(temp_embeds)
 
         return embeds
 
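Note: net effect of the embeddings.py hunks, as a self-contained sketch. Trained placeholder tokens are now looked up in the trainable temp_token_embedding copy and offset by the overlay, while every other token keeps reading the frozen base table. Names mirror the diff; the nn.Linear is only a stand-in for OverlayLinear, whose internals this commit does not touch.

import torch
import torch.nn as nn

vocab, dim = 10, 4
token_embedding = nn.Embedding(vocab, dim)        # frozen base table
temp_token_embedding = nn.Embedding(vocab, dim)   # trainable copy, same init
temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()
overlay = nn.Linear(dim, dim, bias=False)         # stand-in for OverlayLinear
temp_token_ids = torch.tensor([3, 7])             # placeholder-token ids

input_ids = torch.tensor([1, 3, 5, 7])
embeds = token_embedding(input_ids)
mask = torch.isin(input_ids, temp_token_ids)
temp_embeds = temp_token_embedding(input_ids[mask])
embeds[mask] = temp_embeds + overlay(temp_embeds)  # same routing as get_embed()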
diff --git a/train_ti.py b/train_ti.py
index 0ce0056..26ac384 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from functools import partial
 from pathlib import Path
 import math
@@ -307,26 +308,6 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=4/5
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--optimizer",
         type=str,
         default="dadan",
@@ -715,10 +696,6 @@ def main():
         sample_scheduler=sample_scheduler,
         checkpoint_output_dir=checkpoint_output_dir,
         gradient_checkpointing=args.gradient_checkpointing,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
@@ -780,7 +757,10 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.overlay.parameters(),
+        itertools.chain(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.overlay.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
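Note: the optimizer must now update two modules instead of one; itertools.chain merges their parameter iterators into a single param group so both train under the same learning rate. A minimal equivalent, with torch.optim.AdamW standing in for whatever create_optimizer builds (the script's default is "dadan"):

import itertools
import torch

emb = torch.nn.Embedding(10, 4)  # stands in for temp_token_embedding
ovl = torch.nn.Linear(4, 4)      # stands in for the overlay
optimizer = torch.optim.AdamW(
    itertools.chain(emb.parameters(), ovl.parameters()),
    lr=1e-4,
)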
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 19b8d25..33f5fb9 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,6 +1,6 @@
 from typing import Optional
 from functools import partial
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from pathlib import Path
 
 import torch
@@ -13,7 +13,6 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -32,10 +31,6 @@ def textual_inversion_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     gradient_checkpointing: bool = False,
-    use_ema: bool = False,
-    ema_inv_gamma: float = 1.0,
-    ema_power: int = 1,
-    ema_max_decay: float = 0.9999,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
     sample_num_steps: int = 20,
@@ -68,25 +63,6 @@ def textual_inversion_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    if use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.overlay.parameters(),
-            inv_gamma=ema_inv_gamma,
-            power=ema_power,
-            max_value=ema_max_decay,
-        )
-        ema_embeddings.to(accelerator.device)
-    else:
-        ema_embeddings = None
-
-    def ema_context():
-        if ema_embeddings is not None:
-            return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.overlay.parameters()
-            )
-        else:
-            return nullcontext()
-
     def on_accum_model():
         return text_encoder.text_model.embeddings.overlay
 
@@ -98,50 +74,36 @@ def textual_inversion_strategy_callbacks(
     @contextmanager
     def on_eval():
         tokenizer.eval()
-
-        with ema_context():
-            yield
-
-    @torch.no_grad()
-    def on_after_optimize(zero_ids, lr: float):
-        if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.overlay.parameters())
-
-    def on_log():
-        if ema_embeddings is not None:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
+        yield
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
-                text_encoder.text_model.embeddings.save_embed(
-                    ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
-                )
+        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            text_encoder.text_model.embeddings.save_embed(
+                ids,
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+            )
 
     @torch.no_grad()
     def on_sample(step):
-        with ema_context():
-            unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-            text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
+        unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
-            orig_unet_dtype = unet_.dtype
-            orig_text_encoder_dtype = text_encoder_.dtype
+        orig_unet_dtype = unet_.dtype
+        orig_text_encoder_dtype = text_encoder_.dtype
 
-            unet_.to(dtype=weight_dtype)
-            text_encoder_.to(dtype=weight_dtype)
+        unet_.to(dtype=weight_dtype)
+        text_encoder_.to(dtype=weight_dtype)
 
-            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
-            unet_.to(dtype=orig_unet_dtype)
-            text_encoder_.to(dtype=orig_text_encoder_dtype)
+        unet_.to(dtype=orig_unet_dtype)
+        text_encoder_.to(dtype=orig_text_encoder_dtype)
 
-            del unet_
-            del text_encoder_
+        del unet_
+        del text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -150,8 +112,6 @@ def textual_inversion_strategy_callbacks(
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
-        on_after_optimize=on_after_optimize,
-        on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
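Note: with EMA gone, persist() is the only step that still mutates the base table: it folds the trained temp rows plus the overlay residual back into token_embedding and clears the temporary state. A standalone sketch of that lifecycle, under the same stand-in assumptions as the sketch above:

import torch
import torch.nn as nn

vocab, dim = 10, 4
token_embedding = nn.Embedding(vocab, dim)
temp_token_embedding = nn.Embedding(vocab, dim)
temp_token_embedding.weight.data = token_embedding.weight.data.clone().detach()
overlay = nn.Linear(dim, dim, bias=False)  # stand-in for OverlayLinear
temp_token_ids = torch.tensor([3, 7])

# ... training updates temp_token_embedding rows 3 and 7, and the overlay ...

with torch.no_grad():  # mirrors persist()
    embeds = temp_token_embedding.weight.data[temp_token_ids]
    token_embedding.weight.data[temp_token_ids] = embeds + overlay(embeds)
temp_token_ids = torch.tensor([], dtype=torch.long)  # nothing left to overlay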