Diffstat (limited to 'trainer_old/ti.py')
-rw-r--r--  trainer_old/ti.py  168
1 file changed, 168 insertions, 0 deletions
diff --git a/trainer_old/ti.py b/trainer_old/ti.py
new file mode 100644
index 0000000..66393af
--- /dev/null
+++ b/trainer_old/ti.py
@@ -0,0 +1,168 @@
from contextlib import contextmanager, nullcontext

import torch

from slugify import slugify

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from trainer.base import TrainingStrategy, Checkpointer
from training.util import EMAModel


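# Checkpointer specialized for textual inversion: instead of full model weights it
# saves only the learned placeholder-token embeddings, optionally under the
# EMA-averaged copy of those embeddings.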
class TextualInversionCheckpointer(Checkpointer):
    def __init__(
        self,
        ema_embeddings: EMAModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.ema_embeddings = ema_embeddings
        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

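    # Writes one .bin per placeholder token, named "<token>_<step>_<postfix>.bin",
    # with the EMA-averaged embeddings temporarily applied when EMA is enabled.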
    @torch.no_grad()
    def checkpoint(self, step, postfix):
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        ema_context = self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

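    # Sample generation runs with the EMA-averaged embeddings temporarily applied,
    # so previews reflect the smoothed embeddings without disturbing the live
    # training parameters.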
    @torch.no_grad()
    def save_samples(self, step):
        ema_context = self.ema_embeddings.apply_temporary(
            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            super().save_samples(step)


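# Training strategy for textual inversion: only the temporary token embeddings of the
# CLIP text encoder are optimized. The UNet is not trained here and is only toggled
# between train/eval mode when gradient checkpointing is enabled. Optional extras are
# an EMA copy of the embeddings and a decay of the embedding norm toward a target value.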
class TextualInversionTrainingStrategy(TrainingStrategy):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        text_encoder: CLIPTextModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        learning_rate: float,
        gradient_checkpointing: bool = False,
        use_emb_decay: bool = False,
        emb_decay_target: float = 0.4,
        emb_decay_factor: float = 1,
        emb_decay_start: float = 1e-4,
        use_ema: bool = False,
        ema_inv_gamma: float = 1.0,
        ema_power: float = 1,
        ema_max_decay: float = 0.9999,
        *args,
        **kwargs,
    ):
        super().__init__(
            unet=unet,
            text_encoder=text_encoder,
            *args,
            **kwargs
        )

        self.text_encoder = text_encoder
        self.unet = unet

        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

        self.gradient_checkpointing = gradient_checkpointing

        self.learning_rate = learning_rate
        self.use_emb_decay = use_emb_decay
        self.emb_decay_target = emb_decay_target
        self.emb_decay_factor = emb_decay_factor
        self.emb_decay_start = emb_decay_start

        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        self.ema_embeddings = None

        if use_ema:
            self.ema_embeddings = EMAModel(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
                inv_gamma=ema_inv_gamma,
                power=ema_power,
                max_value=ema_max_decay,
            )

        self.checkpointer = TextualInversionCheckpointer(
            unet=unet,
            text_encoder=text_encoder,
            ema_embeddings=self.ema_embeddings,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            *args,
            **kwargs
        )

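    # The text encoder is the model actually being trained; the base class
    # presumably uses main_model to decide what to prepare and optimize.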
    @property
    def main_model(self):
        return self.text_encoder

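    # Gradient checkpointing in diffusers is only active while a module is in train
    # mode, so the (otherwise untrained) UNet is put into train() during training and
    # back into eval() for evaluation.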
    @contextmanager
    def on_train(self, epoch: int):
        try:
            if self.gradient_checkpointing:
                self.unet.train()

            with super().on_train(epoch):
                yield
        finally:
            pass

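    # Evaluation runs with the EMA-averaged embeddings applied temporarily (when EMA
    # is enabled) on top of whatever the base strategy does for eval.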
    @contextmanager
    def on_eval(self):
        try:
            if self.gradient_checkpointing:
                self.unet.eval()

            ema_context = self.ema_embeddings.apply_temporary(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            ) if self.ema_embeddings is not None else nullcontext()

            with ema_context, super().on_eval():
                yield
        finally:
            pass

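    # After each optimizer step: optionally re-normalize the placeholder embeddings
    # toward emb_decay_target, with a strength that scales linearly with the current
    # learning rate (clamped to [0, 1]), then update the EMA copy of the embeddings.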
    @torch.no_grad()
    def on_after_optimize(self, lr: float):
        if self.use_emb_decay:
            self.text_encoder.text_model.embeddings.normalize(
                self.emb_decay_target,
                min(1.0, max(0.0, self.emb_decay_factor * (
                    (lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start)
                )))
            )

        if self.ema_embeddings is not None:
            self.ema_embeddings.step(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            )

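    # Add the current EMA decay to the values logged by the base strategy.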
    def on_log(self):
        log = super().on_log()
        added = {}

        if self.ema_embeddings is not None:
            added = {"ema_decay": self.ema_embeddings.decay}

        return {**log, **added}