Diffstat (limited to 'trainer_old/ti.py')
-rw-r--r--  trainer_old/ti.py  168
1 files changed, 168 insertions, 0 deletions
diff --git a/trainer_old/ti.py b/trainer_old/ti.py
new file mode 100644
index 0000000..66393af
--- /dev/null
+++ b/trainer_old/ti.py
@@ -0,0 +1,168 @@
from contextlib import contextmanager, nullcontext

import torch

from slugify import slugify

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from trainer.base import TrainingStrategy, Checkpointer
from training.util import EMAModel


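# Checkpointer for Textual Inversion: it saves the learned placeholder-token
# embeddings (rather than full model weights), optionally using the EMA-averaged
# embeddings while writing them out.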
class TextualInversionCheckpointer(Checkpointer):
    def __init__(
        self,
        ema_embeddings: EMAModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.ema_embeddings = ema_embeddings
        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

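        # Temporarily swap in the EMA-averaged embeddings (if enabled) while the
        # per-token embeddings are written to disk.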
        ema_context = self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

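    # Samples are likewise generated with the EMA-averaged embeddings when EMA is enabled.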
    @torch.no_grad()
    def save_samples(self, step):
        ema_context = self.ema_embeddings.apply_temporary(
            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            super().save_samples(step)


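# Training strategy for Textual Inversion: only the temporary token embeddings of the
# text encoder are optimized. Supports optional embedding-norm decay and EMA smoothing
# of the embeddings after each optimizer step.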
class TextualInversionTrainingStrategy(TrainingStrategy):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        text_encoder: CLIPTextModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        learning_rate: float,
        gradient_checkpointing: bool = False,
        use_emb_decay: bool = False,
        emb_decay_target: float = 0.4,
        emb_decay_factor: float = 1,
        emb_decay_start: float = 1e-4,
        use_ema: bool = False,
        ema_inv_gamma: float = 1.0,
        ema_power: int = 1,
        ema_max_decay: float = 0.9999,
        *args,
        **kwargs,
    ):
        super().__init__(
            unet=unet,
            text_encoder=text_encoder,
            *args,
            **kwargs
        )

        self.text_encoder = text_encoder
        self.unet = unet

        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

        self.gradient_checkpointing = gradient_checkpointing

        self.learning_rate = learning_rate
        self.use_emb_decay = use_emb_decay
        self.emb_decay_target = emb_decay_target
        self.emb_decay_factor = emb_decay_factor
        self.emb_decay_start = emb_decay_start

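        # Only the temporary token embedding table is trained; enable gradients for it.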
        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        self.ema_embeddings = None

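        # Optionally keep an exponential moving average of the embeddings to smooth
        # the learned vectors over training steps.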
        if use_ema:
            self.ema_embeddings = EMAModel(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
                inv_gamma=ema_inv_gamma,
                power=ema_power,
                max_value=ema_max_decay,
            )

        self.checkpointer = TextualInversionCheckpointer(
            unet=unet,
            text_encoder=text_encoder,
            ema_embeddings=self.ema_embeddings,
            # Forward the placeholder token info the checkpointer requires.
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            *args,
            **kwargs
        )

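    # main_model presumably tells the base TrainingStrategy which module to prepare
    # and optimize; for Textual Inversion that is the text encoder.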
    @property
    def main_model(self):
        return self.text_encoder

    @contextmanager
    def on_train(self, epoch: int):
        try:
            if self.gradient_checkpointing:
                # Gradient checkpointing in diffusers is only active while the module
                # is in train mode, even though the UNet is not optimized here.
                self.unet.train()

            with super().on_train(epoch):
                yield
        finally:
            pass

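    # For evaluation, swap in the EMA-averaged embeddings (when enabled) so that
    # validation and sampling reflect the smoothed weights.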
    @contextmanager
    def on_eval(self):
        try:
            if self.gradient_checkpointing:
                self.unet.eval()

            ema_context = self.ema_embeddings.apply_temporary(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            ) if self.ema_embeddings is not None else nullcontext()

            with ema_context, super().on_eval():
                yield
        finally:
            pass

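    # After each optimizer step: optionally decay the embedding norms toward
    # emb_decay_target (weighted by the current learning rate relative to the base
    # learning rate), then update the EMA copy of the embeddings.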
    @torch.no_grad()
    def on_after_optimize(self, lr: float):
        if self.use_emb_decay:
            self.text_encoder.text_model.embeddings.normalize(
                self.emb_decay_target,
                min(1.0, max(0.0, self.emb_decay_factor * (
                    (lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start)
                )))
            )

        if self.ema_embeddings is not None:
            self.ema_embeddings.step(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            )

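    # Add the current EMA decay to the training logs when EMA is enabled.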
    def on_log(self):
        log = super().on_log()
        added = {}

        if self.ema_embeddings is not None:
            added = {"ema_decay": self.ema_embeddings.decay}

        # dict.update() returns None; merge explicitly and return the combined dict.
        return {**log, **added}