diff options
Diffstat (limited to 'trainer/ti.py')
-rw-r--r-- | trainer/ti.py | 164 |
1 files changed, 0 insertions, 164 deletions
diff --git a/trainer/ti.py b/trainer/ti.py deleted file mode 100644 index 388acd3..0000000 --- a/trainer/ti.py +++ /dev/null | |||
@@ -1,164 +0,0 @@ | |||
1 | from contextlib import contextmanager, nullcontext | ||
2 | |||
3 | import torch | ||
4 | |||
5 | from slugify import slugify | ||
6 | |||
7 | from diffusers import UNet2DConditionModel | ||
8 | from transformers import CLIPTextModel | ||
9 | |||
10 | from trainer.base import TrainingStrategy, Checkpointer | ||
11 | from training.util import EMAModel | ||
12 | |||
13 | |||
class TextualInversionCheckpointer(Checkpointer):
    """Checkpointer for Textual Inversion runs.

    Saves the learned placeholder-token embedding vectors to disk and,
    when an EMA copy of the embeddings is available, performs both
    checkpointing and sample generation under the EMA-averaged weights.
    """

    def __init__(
        self,
        ema_embeddings: EMAModel,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # May be None when EMA is disabled; checked before every use.
        self.ema_embeddings = ema_embeddings

    def _ema_context(self, text_encoder):
        # Temporarily swap the EMA-averaged weights into the temporary token
        # embedding; a no-op context when EMA is disabled.
        if self.ema_embeddings is None:
            return nullcontext()
        return self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        )

    @torch.no_grad()
    def checkpoint(self, step, postfix):
        """Write one .bin embedding file per placeholder token."""
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        # Unwrap from the accelerator so we save the plain model's weights.
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        with self._ema_context(text_encoder):
            for token, ids in zip(self.placeholder_tokens, self.placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin"),
                )

    @torch.no_grad()
    def save_samples(self, step):
        """Generate samples, using the EMA embedding weights when available."""
        with self._ema_context(self.text_encoder):
            super().save_samples(step)
53 | |||
54 | |||
class TextualInversionTrainingStrategy(TrainingStrategy):
    """Training strategy for Textual Inversion.

    Only the text encoder's temporary token embeddings are trained; the
    UNet stays frozen (it is only toggled between train/eval mode to make
    gradient checkpointing work). Optionally applies weight decay toward a
    target norm on the embeddings and maintains an EMA copy of the
    embedding weights that is used during evaluation.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        text_encoder: CLIPTextModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        learning_rate: float,
        gradient_checkpointing: bool = False,
        use_emb_decay: bool = False,
        emb_decay_target: float = 0.4,
        emb_decay_factor: float = 1,
        emb_decay_start: float = 1e-4,
        use_ema: bool = False,
        ema_inv_gamma: float = 1.0,
        ema_power: float = 1,  # exponent of the EMA decay schedule; was mis-annotated as int
        ema_max_decay: float = 0.9999,
        *args,
        **kwargs,
    ):
        super().__init__(
            unet=unet,
            text_encoder=text_encoder,
            *args,
            **kwargs
        )

        self.text_encoder = text_encoder
        self.unet = unet

        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

        self.gradient_checkpointing = gradient_checkpointing

        self.learning_rate = learning_rate
        self.use_emb_decay = use_emb_decay
        self.emb_decay_target = emb_decay_target
        self.emb_decay_factor = emb_decay_factor
        self.emb_decay_start = emb_decay_start

        # Only the temporary token embedding is trainable.
        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        self.ema_embeddings = None

        if use_ema:
            self.ema_embeddings = EMAModel(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
                inv_gamma=ema_inv_gamma,
                power=ema_power,
                max_value=ema_max_decay,
            )

        self.checkpointer = TextualInversionCheckpointer(
            unet=unet,
            text_encoder=text_encoder,
            ema_embeddings=self.ema_embeddings,
            *args,
            **kwargs
        )

    @property
    def main_model(self):
        """The model being trained — here, the text encoder."""
        return self.text_encoder

    @contextmanager
    def on_train(self, epoch: int):
        if self.gradient_checkpointing:
            # Gradient checkpointing only recomputes activations in train mode.
            self.unet.train()

        # BUG FIX: the original entered super().on_eval() here, putting the
        # base strategy into evaluation state while training.
        with super().on_train(epoch):
            yield

    @contextmanager
    def on_eval(self):
        if self.gradient_checkpointing:
            self.unet.eval()

        # Evaluate with the EMA-averaged embedding weights when EMA is on.
        ema_context = (
            self.ema_embeddings.apply_temporary(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            )
            if self.ema_embeddings is not None
            else nullcontext()
        )

        with ema_context, super().on_eval():
            yield

    @torch.no_grad()
    def on_after_optimize(self, lr: float):
        if self.use_emb_decay:
            # Decay strength scales with how far the current lr is through its
            # schedule, clamped to [0, 1].
            self.text_encoder.text_model.embeddings.normalize(
                self.emb_decay_target,
                min(1.0, max(0.0, self.emb_decay_factor * ((lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start))))
            )

        if self.ema_embeddings is not None:
            self.ema_embeddings.step(self.text_encoder.text_model.embeddings.temp_token_embedding.parameters())

    def on_log(self):
        log = super().on_log()

        if self.ema_embeddings is not None:
            # BUG FIX: the original did `return log.update(added)`, but
            # dict.update returns None, so the log was always discarded.
            log = {**log, "ema_decay": self.ema_embeddings.decay}

        return log