diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 63 | ||||
-rw-r--r-- | training/strategy/ti.py | 20 |
2 files changed, 39 insertions, 44 deletions
diff --git a/training/functional.py b/training/functional.py index e54c9c8..4ca7470 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -1,3 +1,4 @@ | |||
1 | from dataclasses import dataclass | ||
1 | import math | 2 | import math |
2 | from contextlib import _GeneratorContextManager, nullcontext | 3 | from contextlib import _GeneratorContextManager, nullcontext |
3 | from typing import Callable, Any, Tuple, Union, Optional | 4 | from typing import Callable, Any, Tuple, Union, Optional |
@@ -14,6 +15,7 @@ from transformers import CLIPTextModel | |||
14 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler | 15 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler |
15 | 16 | ||
16 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
18 | from PIL import Image | ||
17 | 19 | ||
18 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
19 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 21 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
@@ -28,6 +30,18 @@ def const(result=None): | |||
28 | return fn | 30 | return fn |
29 | 31 | ||
30 | 32 | ||
33 | @dataclass | ||
34 | class TrainingCallbacks(): | ||
35 | on_prepare: Callable[[float], None] = const() | ||
36 | on_log: Callable[[], dict[str, Any]] = const({}) | ||
37 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | ||
38 | on_before_optimize: Callable[[int], None] = const() | ||
39 | on_after_optimize: Callable[[float], None] = const() | ||
40 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | ||
41 | on_sample: Callable[[int], None] = const() | ||
42 | on_checkpoint: Callable[[int, str], None] = const() | ||
43 | |||
44 | |||
31 | def make_grid(images, rows, cols): | 45 | def make_grid(images, rows, cols): |
32 | w, h = images[0].size | 46 | w, h = images[0].size |
33 | grid = Image.new('RGB', size=(cols*w, rows*h)) | 47 | grid = Image.new('RGB', size=(cols*w, rows*h)) |
@@ -341,13 +355,7 @@ def train_loop( | |||
341 | checkpoint_frequency: int = 50, | 355 | checkpoint_frequency: int = 50, |
342 | global_step_offset: int = 0, | 356 | global_step_offset: int = 0, |
343 | num_epochs: int = 100, | 357 | num_epochs: int = 100, |
344 | on_log: Callable[[], dict[str, Any]] = const({}), | 358 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
345 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), | ||
346 | on_before_optimize: Callable[[int], None] = const(), | ||
347 | on_after_optimize: Callable[[float], None] = const(), | ||
348 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), | ||
349 | on_sample: Callable[[int], None] = const(), | ||
350 | on_checkpoint: Callable[[int, str], None] = const(), | ||
351 | ): | 359 | ): |
352 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | 360 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) |
353 | num_val_steps_per_epoch = len(val_dataloader) | 361 | num_val_steps_per_epoch = len(val_dataloader) |
@@ -383,24 +391,24 @@ def train_loop( | |||
383 | for epoch in range(num_epochs): | 391 | for epoch in range(num_epochs): |
384 | if accelerator.is_main_process: | 392 | if accelerator.is_main_process: |
385 | if epoch % sample_frequency == 0: | 393 | if epoch % sample_frequency == 0: |
386 | on_sample(global_step + global_step_offset) | 394 | callbacks.on_sample(global_step + global_step_offset) |
387 | 395 | ||
388 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 396 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
389 | on_checkpoint(global_step + global_step_offset, "training") | 397 | callbacks.on_checkpoint(global_step + global_step_offset, "training") |
390 | 398 | ||
391 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 399 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
392 | local_progress_bar.reset() | 400 | local_progress_bar.reset() |
393 | 401 | ||
394 | model.train() | 402 | model.train() |
395 | 403 | ||
396 | with on_train(epoch): | 404 | with callbacks.on_train(epoch): |
397 | for step, batch in enumerate(train_dataloader): | 405 | for step, batch in enumerate(train_dataloader): |
398 | with accelerator.accumulate(model): | 406 | with accelerator.accumulate(model): |
399 | loss, acc, bsz = loss_step(step, batch) | 407 | loss, acc, bsz = loss_step(step, batch) |
400 | 408 | ||
401 | accelerator.backward(loss) | 409 | accelerator.backward(loss) |
402 | 410 | ||
403 | on_before_optimize(epoch) | 411 | callbacks.on_before_optimize(epoch) |
404 | 412 | ||
405 | optimizer.step() | 413 | optimizer.step() |
406 | lr_scheduler.step() | 414 | lr_scheduler.step() |
@@ -411,7 +419,7 @@ def train_loop( | |||
411 | 419 | ||
412 | # Checks if the accelerator has performed an optimization step behind the scenes | 420 | # Checks if the accelerator has performed an optimization step behind the scenes |
413 | if accelerator.sync_gradients: | 421 | if accelerator.sync_gradients: |
414 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | 422 | callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0]) |
415 | 423 | ||
416 | local_progress_bar.update(1) | 424 | local_progress_bar.update(1) |
417 | global_progress_bar.update(1) | 425 | global_progress_bar.update(1) |
@@ -425,7 +433,7 @@ def train_loop( | |||
425 | "train/cur_acc": acc.item(), | 433 | "train/cur_acc": acc.item(), |
426 | "lr": lr_scheduler.get_last_lr()[0], | 434 | "lr": lr_scheduler.get_last_lr()[0], |
427 | } | 435 | } |
428 | logs.update(on_log()) | 436 | logs.update(callbacks.on_log()) |
429 | 437 | ||
430 | accelerator.log(logs, step=global_step) | 438 | accelerator.log(logs, step=global_step) |
431 | 439 | ||
@@ -441,7 +449,7 @@ def train_loop( | |||
441 | cur_loss_val = AverageMeter() | 449 | cur_loss_val = AverageMeter() |
442 | cur_acc_val = AverageMeter() | 450 | cur_acc_val = AverageMeter() |
443 | 451 | ||
444 | with torch.inference_mode(), on_eval(): | 452 | with torch.inference_mode(), callbacks.on_eval(): |
445 | for step, batch in enumerate(val_dataloader): | 453 | for step, batch in enumerate(val_dataloader): |
446 | loss, acc, bsz = loss_step(step, batch, True) | 454 | loss, acc, bsz = loss_step(step, batch, True) |
447 | 455 | ||
@@ -477,20 +485,20 @@ def train_loop( | |||
477 | if avg_acc_val.avg.item() > max_acc_val: | 485 | if avg_acc_val.avg.item() > max_acc_val: |
478 | accelerator.print( | 486 | accelerator.print( |
479 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 487 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
480 | on_checkpoint(global_step + global_step_offset, "milestone") | 488 | callbacks.on_checkpoint(global_step + global_step_offset, "milestone") |
481 | max_acc_val = avg_acc_val.avg.item() | 489 | max_acc_val = avg_acc_val.avg.item() |
482 | 490 | ||
483 | # Create the pipeline using using the trained modules and save it. | 491 | # Create the pipeline using using the trained modules and save it. |
484 | if accelerator.is_main_process: | 492 | if accelerator.is_main_process: |
485 | print("Finished!") | 493 | print("Finished!") |
486 | on_checkpoint(global_step + global_step_offset, "end") | 494 | callbacks.on_checkpoint(global_step + global_step_offset, "end") |
487 | on_sample(global_step + global_step_offset) | 495 | callbacks.on_sample(global_step + global_step_offset) |
488 | accelerator.end_training() | 496 | accelerator.end_training() |
489 | 497 | ||
490 | except KeyboardInterrupt: | 498 | except KeyboardInterrupt: |
491 | if accelerator.is_main_process: | 499 | if accelerator.is_main_process: |
492 | print("Interrupted") | 500 | print("Interrupted") |
493 | on_checkpoint(global_step + global_step_offset, "end") | 501 | callbacks.on_checkpoint(global_step + global_step_offset, "end") |
494 | accelerator.end_training() | 502 | accelerator.end_training() |
495 | 503 | ||
496 | 504 | ||
@@ -511,14 +519,7 @@ def train( | |||
511 | checkpoint_frequency: int = 50, | 519 | checkpoint_frequency: int = 50, |
512 | global_step_offset: int = 0, | 520 | global_step_offset: int = 0, |
513 | prior_loss_weight: float = 0, | 521 | prior_loss_weight: float = 0, |
514 | on_prepare: Callable[[], dict[str, Any]] = const({}), | 522 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
515 | on_log: Callable[[], dict[str, Any]] = const({}), | ||
516 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), | ||
517 | on_before_optimize: Callable[[int], None] = const(), | ||
518 | on_after_optimize: Callable[[float], None] = const(), | ||
519 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), | ||
520 | on_sample: Callable[[int], None] = const(), | ||
521 | on_checkpoint: Callable[[int, str], None] = const(), | ||
522 | ): | 523 | ): |
523 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 524 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
524 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 525 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler |
@@ -530,7 +531,7 @@ def train( | |||
530 | model.requires_grad_(False) | 531 | model.requires_grad_(False) |
531 | model.eval() | 532 | model.eval() |
532 | 533 | ||
533 | on_prepare() | 534 | callbacks.on_prepare() |
534 | 535 | ||
535 | loss_step_ = partial( | 536 | loss_step_ = partial( |
536 | loss_step, | 537 | loss_step, |
@@ -557,13 +558,7 @@ def train( | |||
557 | checkpoint_frequency=checkpoint_frequency, | 558 | checkpoint_frequency=checkpoint_frequency, |
558 | global_step_offset=global_step_offset, | 559 | global_step_offset=global_step_offset, |
559 | num_epochs=num_train_epochs, | 560 | num_epochs=num_train_epochs, |
560 | on_log=on_log, | 561 | callbacks=callbacks, |
561 | on_train=on_train, | ||
562 | on_before_optimize=on_before_optimize, | ||
563 | on_after_optimize=on_after_optimize, | ||
564 | on_eval=on_eval, | ||
565 | on_sample=on_sample, | ||
566 | on_checkpoint=on_checkpoint, | ||
567 | ) | 562 | ) |
568 | 563 | ||
569 | accelerator.free_memory() | 564 | accelerator.free_memory() |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 83dc566..6f8384f 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -15,7 +15,7 @@ from slugify import slugify | |||
15 | 15 | ||
16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
18 | from training.functional import save_samples | 18 | from training.functional import TrainingCallbacks, save_samples |
19 | 19 | ||
20 | 20 | ||
21 | def textual_inversion_strategy( | 21 | def textual_inversion_strategy( |
@@ -153,12 +153,12 @@ def textual_inversion_strategy( | |||
153 | with ema_context: | 153 | with ema_context: |
154 | save_samples_(step=step) | 154 | save_samples_(step=step) |
155 | 155 | ||
156 | return { | 156 | return TrainingCallbacks( |
157 | "on_prepare": on_prepare, | 157 | on_prepare=on_prepare, |
158 | "on_train": on_train, | 158 | on_train=on_train, |
159 | "on_eval": on_eval, | 159 | on_eval=on_eval, |
160 | "on_after_optimize": on_after_optimize, | 160 | on_after_optimize=on_after_optimize, |
161 | "on_log": on_log, | 161 | on_log=on_log, |
162 | "on_checkpoint": on_checkpoint, | 162 | on_checkpoint=on_checkpoint, |
163 | "on_sample": on_sample, | 163 | on_sample=on_sample, |
164 | } | 164 | ) |