diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/common.py | 205 | ||||
-rw-r--r-- | training/util.py | 13 |
2 files changed, 208 insertions, 10 deletions
diff --git a/training/common.py b/training/common.py index 90cf910..842ac07 100644 --- a/training/common.py +++ b/training/common.py | |||
@@ -1,14 +1,30 @@ | |||
1 | import math | 1 | import math |
2 | from contextlib import _GeneratorContextManager, nullcontext | ||
3 | from typing import Callable, Any, Tuple, Union | ||
2 | 4 | ||
3 | import torch | 5 | import torch |
4 | import torch.nn.functional as F | 6 | import torch.nn.functional as F |
7 | from torch.utils.data import DataLoader | ||
5 | 8 | ||
9 | from accelerate import Accelerator | ||
10 | from transformers import CLIPTokenizer, CLIPTextModel | ||
6 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 11 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
7 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 12 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup |
8 | 13 | ||
9 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 14 | from tqdm.auto import tqdm |
10 | 15 | ||
16 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
17 | from models.clip.util import get_extended_embeddings | ||
11 | from training.optimization import get_one_cycle_schedule | 18 | from training.optimization import get_one_cycle_schedule |
19 | from training.util import AverageMeter, CheckpointerBase | ||
20 | |||
21 | |||
22 | def noop(*args, **kwards): | ||
23 | pass | ||
24 | |||
25 | |||
26 | def noop_on_log(): | ||
27 | return {} | ||
12 | 28 | ||
13 | 29 | ||
14 | def get_scheduler( | 30 | def get_scheduler( |
@@ -22,10 +38,11 @@ def get_scheduler( | |||
22 | cycles: int, | 38 | cycles: int, |
23 | warmup_epochs: int, | 39 | warmup_epochs: int, |
24 | optimizer: torch.optim.Optimizer, | 40 | optimizer: torch.optim.Optimizer, |
25 | max_train_steps: int, | 41 | num_train_epochs: int, |
26 | num_update_steps_per_epoch: int, | 42 | num_update_steps_per_epoch: int, |
27 | gradient_accumulation_steps: int, | 43 | gradient_accumulation_steps: int, |
28 | ): | 44 | ): |
45 | num_train_steps = num_train_epochs * num_update_steps_per_epoch | ||
29 | warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps | 46 | warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps |
30 | 47 | ||
31 | if id == "one_cycle": | 48 | if id == "one_cycle": |
@@ -33,7 +50,7 @@ def get_scheduler( | |||
33 | 50 | ||
34 | lr_scheduler = get_one_cycle_schedule( | 51 | lr_scheduler = get_one_cycle_schedule( |
35 | optimizer=optimizer, | 52 | optimizer=optimizer, |
36 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 53 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
37 | warmup=warmup_func, | 54 | warmup=warmup_func, |
38 | annealing=annealing_func, | 55 | annealing=annealing_func, |
39 | warmup_exp=warmup_exp, | 56 | warmup_exp=warmup_exp, |
@@ -42,12 +59,12 @@ def get_scheduler( | |||
42 | ) | 59 | ) |
43 | elif id == "cosine_with_restarts": | 60 | elif id == "cosine_with_restarts": |
44 | cycles = cycles if cycles is not None else math.ceil( | 61 | cycles = cycles if cycles is not None else math.ceil( |
45 | math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch))) | 62 | math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch))) |
46 | 63 | ||
47 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 64 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
48 | optimizer=optimizer, | 65 | optimizer=optimizer, |
49 | num_warmup_steps=warmup_steps, | 66 | num_warmup_steps=warmup_steps, |
50 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 67 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
51 | num_cycles=cycles, | 68 | num_cycles=cycles, |
52 | ) | 69 | ) |
53 | else: | 70 | else: |
@@ -55,7 +72,7 @@ def get_scheduler( | |||
55 | id, | 72 | id, |
56 | optimizer=optimizer, | 73 | optimizer=optimizer, |
57 | num_warmup_steps=warmup_steps, | 74 | num_warmup_steps=warmup_steps, |
58 | num_training_steps=max_train_steps * gradient_accumulation_steps, | 75 | num_training_steps=num_train_steps * gradient_accumulation_steps, |
59 | ) | 76 | ) |
60 | 77 | ||
61 | return lr_scheduler | 78 | return lr_scheduler |
@@ -117,12 +134,12 @@ def loss_step( | |||
117 | vae: AutoencoderKL, | 134 | vae: AutoencoderKL, |
118 | noise_scheduler: DDPMScheduler, | 135 | noise_scheduler: DDPMScheduler, |
119 | unet: UNet2DConditionModel, | 136 | unet: UNet2DConditionModel, |
120 | prompt_processor, | 137 | text_encoder: CLIPTextModel, |
121 | num_class_images: int, | 138 | num_class_images: int, |
122 | prior_loss_weight: float, | 139 | prior_loss_weight: float, |
123 | seed: int, | 140 | seed: int, |
124 | step: int, | 141 | step: int, |
125 | batch, | 142 | batch: dict[str, Any], |
126 | eval: bool = False | 143 | eval: bool = False |
127 | ): | 144 | ): |
128 | # Convert images to latent space | 145 | # Convert images to latent space |
@@ -149,7 +166,8 @@ def loss_step( | |||
149 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | 166 | noisy_latents = noisy_latents.to(dtype=unet.dtype) |
150 | 167 | ||
151 | # Get the text embedding for conditioning | 168 | # Get the text embedding for conditioning |
152 | encoder_hidden_states = prompt_processor.get_embeddings( | 169 | encoder_hidden_states = get_extended_embeddings( |
170 | text_encoder, | ||
153 | batch["input_ids"], | 171 | batch["input_ids"], |
154 | batch["attention_mask"] | 172 | batch["attention_mask"] |
155 | ) | 173 | ) |
@@ -185,3 +203,172 @@ def loss_step( | |||
185 | acc = (model_pred == target).float().mean() | 203 | acc = (model_pred == target).float().mean() |
186 | 204 | ||
187 | return loss, acc, bsz | 205 | return loss, acc, bsz |
206 | |||
207 | |||
208 | def train_loop( | ||
209 | accelerator: Accelerator, | ||
210 | optimizer: torch.optim.Optimizer, | ||
211 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
212 | model: torch.nn.Module, | ||
213 | checkpointer: CheckpointerBase, | ||
214 | train_dataloader: DataLoader, | ||
215 | val_dataloader: DataLoader, | ||
216 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | ||
217 | sample_frequency: int = 10, | ||
218 | sample_steps: int = 20, | ||
219 | checkpoint_frequency: int = 50, | ||
220 | global_step_offset: int = 0, | ||
221 | gradient_accumulation_steps: int = 1, | ||
222 | num_epochs: int = 100, | ||
223 | on_log: Callable[[], dict[str, Any]] = noop_on_log, | ||
224 | on_train: Callable[[], _GeneratorContextManager] = nullcontext, | ||
225 | on_before_optimize: Callable[[], None] = noop, | ||
226 | on_after_optimize: Callable[[float], None] = noop, | ||
227 | on_eval: Callable[[], _GeneratorContextManager] = nullcontext | ||
228 | ): | ||
229 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | ||
230 | num_train_steps = num_epochs * num_update_steps_per_epoch | ||
231 | |||
232 | num_val_steps_per_epoch = len(val_dataloader) | ||
233 | num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch) | ||
234 | num_val_steps = num_val_steps_per_epoch * num_epochs | ||
235 | |||
236 | global_step = 0 | ||
237 | |||
238 | avg_loss = AverageMeter() | ||
239 | avg_acc = AverageMeter() | ||
240 | |||
241 | avg_loss_val = AverageMeter() | ||
242 | avg_acc_val = AverageMeter() | ||
243 | |||
244 | max_acc_val = 0.0 | ||
245 | |||
246 | local_progress_bar = tqdm( | ||
247 | range(num_update_steps_per_epoch + num_val_steps_per_epoch), | ||
248 | disable=not accelerator.is_local_main_process, | ||
249 | dynamic_ncols=True | ||
250 | ) | ||
251 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | ||
252 | |||
253 | global_progress_bar = tqdm( | ||
254 | range(num_train_steps + num_val_steps), | ||
255 | disable=not accelerator.is_local_main_process, | ||
256 | dynamic_ncols=True | ||
257 | ) | ||
258 | global_progress_bar.set_description("Total progress") | ||
259 | |||
260 | try: | ||
261 | for epoch in range(num_epochs): | ||
262 | if accelerator.is_main_process: | ||
263 | if epoch % sample_frequency == 0: | ||
264 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | ||
265 | |||
266 | if epoch % checkpoint_frequency == 0 and epoch != 0: | ||
267 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
268 | |||
269 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | ||
270 | local_progress_bar.reset() | ||
271 | |||
272 | model.train() | ||
273 | |||
274 | with on_train(): | ||
275 | for step, batch in enumerate(train_dataloader): | ||
276 | with accelerator.accumulate(model): | ||
277 | loss, acc, bsz = loss_step(step, batch) | ||
278 | |||
279 | accelerator.backward(loss) | ||
280 | |||
281 | on_before_optimize() | ||
282 | |||
283 | optimizer.step() | ||
284 | lr_scheduler.step() | ||
285 | optimizer.zero_grad(set_to_none=True) | ||
286 | |||
287 | avg_loss.update(loss.detach_(), bsz) | ||
288 | avg_acc.update(acc.detach_(), bsz) | ||
289 | |||
290 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
291 | if accelerator.sync_gradients: | ||
292 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
293 | |||
294 | local_progress_bar.update(1) | ||
295 | global_progress_bar.update(1) | ||
296 | |||
297 | global_step += 1 | ||
298 | |||
299 | logs = { | ||
300 | "train/loss": avg_loss.avg.item(), | ||
301 | "train/acc": avg_acc.avg.item(), | ||
302 | "train/cur_loss": loss.item(), | ||
303 | "train/cur_acc": acc.item(), | ||
304 | "lr": lr_scheduler.get_last_lr()[0], | ||
305 | } | ||
306 | logs.update(on_log()) | ||
307 | |||
308 | accelerator.log(logs, step=global_step) | ||
309 | |||
310 | local_progress_bar.set_postfix(**logs) | ||
311 | |||
312 | if global_step >= num_train_steps: | ||
313 | break | ||
314 | |||
315 | accelerator.wait_for_everyone() | ||
316 | |||
317 | model.eval() | ||
318 | |||
319 | cur_loss_val = AverageMeter() | ||
320 | cur_acc_val = AverageMeter() | ||
321 | |||
322 | with torch.inference_mode(): | ||
323 | with on_eval(): | ||
324 | for step, batch in enumerate(val_dataloader): | ||
325 | loss, acc, bsz = loss_step(step, batch, True) | ||
326 | |||
327 | loss = loss.detach_() | ||
328 | acc = acc.detach_() | ||
329 | |||
330 | cur_loss_val.update(loss, bsz) | ||
331 | cur_acc_val.update(acc, bsz) | ||
332 | |||
333 | avg_loss_val.update(loss, bsz) | ||
334 | avg_acc_val.update(acc, bsz) | ||
335 | |||
336 | local_progress_bar.update(1) | ||
337 | global_progress_bar.update(1) | ||
338 | |||
339 | logs = { | ||
340 | "val/loss": avg_loss_val.avg.item(), | ||
341 | "val/acc": avg_acc_val.avg.item(), | ||
342 | "val/cur_loss": loss.item(), | ||
343 | "val/cur_acc": acc.item(), | ||
344 | } | ||
345 | local_progress_bar.set_postfix(**logs) | ||
346 | |||
347 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
348 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
349 | |||
350 | accelerator.log(logs, step=global_step) | ||
351 | |||
352 | local_progress_bar.clear() | ||
353 | global_progress_bar.clear() | ||
354 | |||
355 | if accelerator.is_main_process: | ||
356 | if avg_acc_val.avg.item() > max_acc_val: | ||
357 | accelerator.print( | ||
358 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
359 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | ||
360 | max_acc_val = avg_acc_val.avg.item() | ||
361 | |||
362 | # Create the pipeline using using the trained modules and save it. | ||
363 | if accelerator.is_main_process: | ||
364 | print("Finished!") | ||
365 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
366 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | ||
367 | accelerator.end_training() | ||
368 | |||
369 | except KeyboardInterrupt: | ||
370 | if accelerator.is_main_process: | ||
371 | print("Interrupted") | ||
372 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
373 | accelerator.end_training() | ||
374 | quit() | ||
diff --git a/training/util.py b/training/util.py index 60d64f0..0ec2032 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -55,8 +55,19 @@ class CheckpointerBase: | |||
55 | self.sample_batches = sample_batches | 55 | self.sample_batches = sample_batches |
56 | self.sample_batch_size = sample_batch_size | 56 | self.sample_batch_size = sample_batch_size |
57 | 57 | ||
58 | @torch.no_grad() | ||
59 | def checkpoint(self, step: int, postfix: str): | ||
60 | pass | ||
61 | |||
58 | @torch.inference_mode() | 62 | @torch.inference_mode() |
59 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 63 | def save_samples( |
64 | self, | ||
65 | pipeline, | ||
66 | step: int, | ||
67 | num_inference_steps: int, | ||
68 | guidance_scale: float = 7.5, | ||
69 | eta: float = 0.0 | ||
70 | ): | ||
60 | samples_path = Path(self.output_dir).joinpath("samples") | 71 | samples_path = Path(self.output_dir).joinpath("samples") |
61 | 72 | ||
62 | train_data = self.datamodule.train_dataloader | 73 | train_data = self.datamodule.train_dataloader |