Diffstat (limited to 'trainer/ti.py')
-rw-r--r--  trainer/ti.py  164
1 files changed, 164 insertions, 0 deletions
diff --git a/trainer/ti.py b/trainer/ti.py
new file mode 100644
index 0000000..15cf747
--- /dev/null
+++ b/trainer/ti.py
@@ -0,0 +1,164 @@
from contextlib import contextmanager, nullcontext

import torch

from slugify import slugify

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from trainer.base import TrainingStrategy, Checkpointer
from training.util import EMAModel


class TextualInversionCheckpointer(Checkpointer):
    def __init__(
        self,
        ema_embeddings: EMAModel,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.ema_embeddings = ema_embeddings

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        # Temporarily swap in the EMA-averaged embeddings (if enabled) while saving.
        ema_context = self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            # Save one embedding file per placeholder token.
            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

    @torch.inference_mode()
    def save_samples(self, step):
        # Generate samples with the EMA-averaged embeddings applied, if available.
        ema_context = self.ema_embeddings.apply_temporary(
            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            super().save_samples(step)


class TextualInversionTrainingStrategy(TrainingStrategy):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        text_encoder: CLIPTextModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        learning_rate: float,
        gradient_checkpointing: bool = False,
        use_emb_decay: bool = False,
        emb_decay_target: float = 0.4,
        emb_decay_factor: float = 1,
        emb_decay_start: float = 1e-4,
        use_ema: bool = False,
        ema_inv_gamma: float = 1.0,
        ema_power: int = 1,
        ema_max_decay: float = 0.9999,
        *args,
        **kwargs,
    ):
        super().__init__(
            unet=unet,
            text_encoder=text_encoder,
            *args,
            **kwargs
        )

        self.text_encoder = text_encoder
        self.unet = unet

        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

        self.gradient_checkpointing = gradient_checkpointing

        self.learning_rate = learning_rate
        self.use_emb_decay = use_emb_decay
        self.emb_decay_target = emb_decay_target
        self.emb_decay_factor = emb_decay_factor
        self.emb_decay_start = emb_decay_start

        # Only the temporary token embeddings are trained; the rest of the
        # text encoder stays frozen.
        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        self.ema_embeddings = None

        if use_ema:
            self.ema_embeddings = EMAModel(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
                inv_gamma=ema_inv_gamma,
                power=ema_power,
                max_value=ema_max_decay,
            )

        self.checkpointer = TextualInversionCheckpointer(
            unet=unet,
            text_encoder=text_encoder,
            ema_embeddings=self.ema_embeddings,
            *args,
            **kwargs
        )

    @property
    def main_model(self):
        return self.text_encoder

    @contextmanager
    def on_train(self, epoch: int):
        try:
            if self.gradient_checkpointing:
                self.unet.train()

            # Enter the base-class training context.
            with super().on_train(epoch):
                yield
        finally:
            pass

    @contextmanager
    def on_eval(self):
        try:
            if self.gradient_checkpointing:
                self.unet.eval()

            ema_context = self.ema_embeddings.apply_temporary(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            ) if self.ema_embeddings is not None else nullcontext()

            with ema_context, super().on_eval():
                yield
        finally:
            pass

    @torch.no_grad()
    def on_after_optimize(self, lr: float):
        if self.use_emb_decay:
            # Decay strength is scaled with the current learning rate and clamped to [0, 1].
            self.text_encoder.text_model.embeddings.normalize(
                self.emb_decay_target,
                min(1.0, max(0.0, self.emb_decay_factor * ((lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start))))
            )

        if self.ema_embeddings is not None:
            self.ema_embeddings.step(self.text_encoder.text_model.embeddings.temp_token_embedding.parameters())

    def on_log(self):
        log = super().on_log()
        added = {}

        if self.ema_embeddings is not None:
            added = {"ema_decay": self.ema_embeddings.decay}

        # dict.update() returns None, so merge and return a new dict instead.
        return {**log, **added}
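
A minimal usage sketch, assuming the base TrainingStrategy and Checkpointer accept the remaining setup objects (accelerator, output_dir, dataloaders, scheduler, ...) through the forwarded *args/**kwargs; those names are not part of this diff and are illustrative only:

# Hypothetical wiring; everything except the parameters defined in this file
# (unet, text_encoder, placeholder_tokens, placeholder_token_ids,
# learning_rate, use_ema, ...) is an assumption about the base classes.
strategy = TextualInversionTrainingStrategy(
    unet=unet,
    text_encoder=text_encoder,
    placeholder_tokens=["<concept>"],
    placeholder_token_ids=[new_token_ids],  # ids assigned when the tokens were added to the tokenizer
    learning_rate=1e-4,
    use_ema=True,
    accelerator=accelerator,                # assumed base-class keyword arguments
    output_dir=output_dir,
)

for epoch in range(num_epochs):
    with strategy.on_train(epoch):
        for batch in train_dataloader:
            ...                              # forward/backward + optimizer step
            strategy.on_after_optimize(lr_scheduler.get_last_lr()[0])

    with strategy.on_eval():
        ...                                  # validation / sample generation

strategy.checkpointer.checkpoint(global_step, "end")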