summaryrefslogtreecommitdiffstats
path: root/training/modules/ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/modules/ti.py')
-rw-r--r--training/modules/ti.py284
1 files changed, 284 insertions, 0 deletions
diff --git a/training/modules/ti.py b/training/modules/ti.py
new file mode 100644
index 0000000..2db6f88
--- /dev/null
+++ b/training/modules/ti.py
@@ -0,0 +1,284 @@
1from typing import Literal
2from functools import partial
3from contextlib import contextmanager, nullcontext
4
5import torch
6
7from slugify import slugify
8
9from accelerate import Accelerator
10from transformers import CLIPTextModel
11from diffusers import AutoencoderKL, UNet2DConditionModel
12
13from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
14from models.clip.tokenizer import MultiCLIPTokenizer
15
16from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
17from training.util import EMAModel, CheckpointerBase
18
19
20class Checkpointer(CheckpointerBase):
21 def __init__(
22 self,
23 accelerator: Accelerator,
24 vae: AutoencoderKL,
25 unet: UNet2DConditionModel,
26 tokenizer: MultiCLIPTokenizer,
27 text_encoder: CLIPTextModel,
28 ema_embeddings: EMAModel,
29 weight_dtype: torch.dtype,
30 scheduler,
31 placeholder_token,
32 placeholder_token_ids,
33 *args,
34 **kwargs
35 ):
36 super().__init__(*args, **kwargs)
37
38 self.weight_dtype = weight_dtype
39 self.accelerator = accelerator
40 self.vae = vae
41 self.unet = unet
42 self.tokenizer = tokenizer
43 self.text_encoder = text_encoder
44 self.ema_embeddings = ema_embeddings
45 self.scheduler = scheduler
46 self.placeholder_token = placeholder_token
47 self.placeholder_token_ids = placeholder_token_ids
48
49 @torch.no_grad()
50 def checkpoint(self, step, postfix):
51 print("Saving checkpoint for step %d..." % step)
52
53 checkpoints_path = self.output_dir.joinpath("checkpoints")
54 checkpoints_path.mkdir(parents=True, exist_ok=True)
55
56 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
57
58 ema_context = nullcontext()
59 if self.ema_embeddings is not None:
60 ema_context = self.ema_embeddings.apply_temporary(
61 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
62
63 with ema_context:
64 for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
65 text_encoder.text_model.embeddings.save_embed(
66 ids,
67 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
68 )
69
70 del text_encoder
71
72 @torch.no_grad()
73 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
74 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
75
76 ema_context = nullcontext()
77 if self.ema_embeddings is not None:
78 ema_context = self.ema_embeddings.apply_temporary(
79 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
80
81 with ema_context:
82 orig_dtype = text_encoder.dtype
83 text_encoder.to(dtype=self.weight_dtype)
84
85 pipeline = VlpnStableDiffusion(
86 text_encoder=text_encoder,
87 vae=self.vae,
88 unet=self.unet,
89 tokenizer=self.tokenizer,
90 scheduler=self.scheduler,
91 ).to(self.accelerator.device)
92 pipeline.set_progress_bar_config(dynamic_ncols=True)
93
94 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
95
96 text_encoder.to(dtype=orig_dtype)
97
98 del text_encoder
99 del pipeline
100
101 if torch.cuda.is_available():
102 torch.cuda.empty_cache()
103
104
105def train_ti(
106 setup: TrainingSetup,
107 num_train_epochs: int = 100,
108 num_class_images: int = 0,
109 prior_loss_weight: float = 1.0,
110 use_ema: bool = False,
111 ema_inv_gamma: float = 1.0,
112 ema_power: float = 4/5,
113 ema_max_decay: float = .9999,
114 adam_beta1: float = 0.9,
115 adam_beta2: float = 0.999,
116 adam_weight_decay: float = 0,
117 adam_epsilon: float = 1e-08,
118 adam_amsgrad: bool = False,
119 lr_scheduler: Literal[
120 "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle"
121 ] = "one_cycle",
122 lr_min_lr: float = 0.04,
123 lr_warmup_func: Literal["linear", "cos"] = "cos",
124 lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
125 lr_warmup_exp: int = 1,
126 lr_annealing_exp: int = 1,
127 lr_cycles: int = 1,
128 lr_warmup_epochs: int = 10,
129 emb_decay_target: float = 0.4,
130 emb_decay_factor: float = 1,
131 emb_decay_start: float = 1e-4,
132 sample_image_size: int = 768,
133 sample_batch_size: int = 1,
134 sample_batches: int = 1,
135 sample_frequency: int = 10,
136 sample_steps: int = 20,
137 checkpoint_frequency: int = 50,
138 global_step_offset: int = 0,
139):
140 if use_ema:
141 ema_embeddings = EMAModel(
142 setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
143 inv_gamma=ema_inv_gamma,
144 power=ema_power,
145 max_value=ema_max_decay,
146 )
147 else:
148 ema_embeddings = None
149
150 setup.text_encoder.requires_grad_(True)
151 setup.text_encoder.text_model.encoder.requires_grad_(False)
152 setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
153 setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
154 setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
155
156 # Initialize the optimizer
157 optimizer = setup.optimizer_class(
158 setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
159 lr=setup.learning_rate,
160 betas=(adam_beta1, adam_beta2),
161 weight_decay=adam_weight_decay,
162 eps=adam_epsilon,
163 amsgrad=adam_amsgrad,
164 )
165
166 lr_scheduler = get_scheduler(
167 lr_scheduler,
168 optimizer=optimizer,
169 min_lr=lr_min_lr,
170 warmup_func=lr_warmup_func,
171 annealing_func=lr_annealing_func,
172 warmup_exp=lr_warmup_exp,
173 annealing_exp=lr_annealing_exp,
174 cycles=lr_cycles,
175 train_epochs=num_train_epochs,
176 warmup_epochs=lr_warmup_epochs,
177 num_training_steps_per_epoch=len(setup.train_dataloader),
178 gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
179 )
180
181 text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
182 setup.text_encoder, optimizer, lr_scheduler
183 )
184
185 # Move vae and unet to device
186 setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
187 setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)
188
189 if use_ema:
190 ema_embeddings.to(setup.accelerator.device)
191
192 setup.unet.train()
193
194 @contextmanager
195 def on_train(epoch: int):
196 try:
197 setup.tokenizer.train()
198 yield
199 finally:
200 pass
201
202 @contextmanager
203 def on_eval():
204 try:
205 setup.tokenizer.eval()
206
207 ema_context = nullcontext()
208 if use_ema:
209 ema_context = ema_embeddings.apply_temporary(
210 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
211
212 with ema_context:
213 yield
214 finally:
215 pass
216
217 @torch.no_grad()
218 def on_after_optimize(lr: float):
219 text_encoder.text_model.embeddings.normalize(
220 emb_decay_target,
221 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
222 )
223
224 if use_ema:
225 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
226
227 def on_log():
228 if use_ema:
229 return {"ema_decay": ema_embeddings.decay}
230 return {}
231
232 loss_step_ = partial(
233 loss_step,
234 setup.vae,
235 setup.noise_scheduler,
236 setup.unet,
237 text_encoder,
238 num_class_images != 0,
239 prior_loss_weight,
240 setup.seed,
241 )
242
243 checkpointer = Checkpointer(
244 accelerator=setup.accelerator,
245 vae=setup.vae,
246 unet=setup.unet,
247 tokenizer=setup.tokenizer,
248 text_encoder=text_encoder,
249 ema_embeddings=ema_embeddings,
250 weight_dtype=setup.weight_dtype,
251 scheduler=setup.checkpoint_scheduler,
252 placeholder_token=setup.placeholder_token,
253 placeholder_token_ids=setup.placeholder_token_ids,
254 train_dataloader=setup.train_dataloader,
255 val_dataloader=setup.val_dataloader,
256 output_dir=setup.output_dir,
257 seed=setup.seed,
258 sample_image_size=sample_image_size,
259 sample_batch_size=sample_batch_size,
260 sample_batches=sample_batches
261 )
262
263 if setup.accelerator.is_main_process:
264 setup.accelerator.init_trackers("textual_inversion")
265
266 train_loop(
267 accelerator=setup.accelerator,
268 optimizer=optimizer,
269 lr_scheduler=lr_scheduler,
270 model=text_encoder,
271 checkpointer=checkpointer,
272 train_dataloader=setup.train_dataloader,
273 val_dataloader=setup.val_dataloader,
274 loss_step=loss_step_,
275 sample_frequency=sample_frequency,
276 sample_steps=sample_steps,
277 checkpoint_frequency=checkpoint_frequency,
278 global_step_offset=global_step_offset,
279 num_epochs=num_train_epochs,
280 on_log=on_log,
281 on_train=on_train,
282 on_after_optimize=on_after_optimize,
283 on_eval=on_eval
284 )