diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 63 | ||||
| -rw-r--r-- | training/strategy/ti.py | 20 |
2 files changed, 39 insertions, 44 deletions
diff --git a/training/functional.py b/training/functional.py index e54c9c8..4ca7470 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -1,3 +1,4 @@ | |||
| 1 | from dataclasses import dataclass | ||
| 1 | import math | 2 | import math |
| 2 | from contextlib import _GeneratorContextManager, nullcontext | 3 | from contextlib import _GeneratorContextManager, nullcontext |
| 3 | from typing import Callable, Any, Tuple, Union, Optional | 4 | from typing import Callable, Any, Tuple, Union, Optional |
| @@ -14,6 +15,7 @@ from transformers import CLIPTextModel | |||
| 14 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler | 15 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler |
| 15 | 16 | ||
| 16 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
| 18 | from PIL import Image | ||
| 17 | 19 | ||
| 18 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 19 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 21 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
| @@ -28,6 +30,18 @@ def const(result=None): | |||
| 28 | return fn | 30 | return fn |
| 29 | 31 | ||
| 30 | 32 | ||
| 33 | @dataclass | ||
| 34 | class TrainingCallbacks(): | ||
| 35 | on_prepare: Callable[[float], None] = const() | ||
| 36 | on_log: Callable[[], dict[str, Any]] = const({}) | ||
| 37 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | ||
| 38 | on_before_optimize: Callable[[int], None] = const() | ||
| 39 | on_after_optimize: Callable[[float], None] = const() | ||
| 40 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | ||
| 41 | on_sample: Callable[[int], None] = const() | ||
| 42 | on_checkpoint: Callable[[int, str], None] = const() | ||
| 43 | |||
| 44 | |||
| 31 | def make_grid(images, rows, cols): | 45 | def make_grid(images, rows, cols): |
| 32 | w, h = images[0].size | 46 | w, h = images[0].size |
| 33 | grid = Image.new('RGB', size=(cols*w, rows*h)) | 47 | grid = Image.new('RGB', size=(cols*w, rows*h)) |
| @@ -341,13 +355,7 @@ def train_loop( | |||
| 341 | checkpoint_frequency: int = 50, | 355 | checkpoint_frequency: int = 50, |
| 342 | global_step_offset: int = 0, | 356 | global_step_offset: int = 0, |
| 343 | num_epochs: int = 100, | 357 | num_epochs: int = 100, |
| 344 | on_log: Callable[[], dict[str, Any]] = const({}), | 358 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
| 345 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), | ||
| 346 | on_before_optimize: Callable[[int], None] = const(), | ||
| 347 | on_after_optimize: Callable[[float], None] = const(), | ||
| 348 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), | ||
| 349 | on_sample: Callable[[int], None] = const(), | ||
| 350 | on_checkpoint: Callable[[int, str], None] = const(), | ||
| 351 | ): | 359 | ): |
| 352 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | 360 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) |
| 353 | num_val_steps_per_epoch = len(val_dataloader) | 361 | num_val_steps_per_epoch = len(val_dataloader) |
| @@ -383,24 +391,24 @@ def train_loop( | |||
| 383 | for epoch in range(num_epochs): | 391 | for epoch in range(num_epochs): |
| 384 | if accelerator.is_main_process: | 392 | if accelerator.is_main_process: |
| 385 | if epoch % sample_frequency == 0: | 393 | if epoch % sample_frequency == 0: |
| 386 | on_sample(global_step + global_step_offset) | 394 | callbacks.on_sample(global_step + global_step_offset) |
| 387 | 395 | ||
| 388 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 396 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
| 389 | on_checkpoint(global_step + global_step_offset, "training") | 397 | callbacks.on_checkpoint(global_step + global_step_offset, "training") |
| 390 | 398 | ||
| 391 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 399 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| 392 | local_progress_bar.reset() | 400 | local_progress_bar.reset() |
| 393 | 401 | ||
| 394 | model.train() | 402 | model.train() |
| 395 | 403 | ||
| 396 | with on_train(epoch): | 404 | with callbacks.on_train(epoch): |
| 397 | for step, batch in enumerate(train_dataloader): | 405 | for step, batch in enumerate(train_dataloader): |
| 398 | with accelerator.accumulate(model): | 406 | with accelerator.accumulate(model): |
| 399 | loss, acc, bsz = loss_step(step, batch) | 407 | loss, acc, bsz = loss_step(step, batch) |
| 400 | 408 | ||
| 401 | accelerator.backward(loss) | 409 | accelerator.backward(loss) |
| 402 | 410 | ||
| 403 | on_before_optimize(epoch) | 411 | callbacks.on_before_optimize(epoch) |
| 404 | 412 | ||
| 405 | optimizer.step() | 413 | optimizer.step() |
| 406 | lr_scheduler.step() | 414 | lr_scheduler.step() |
| @@ -411,7 +419,7 @@ def train_loop( | |||
| 411 | 419 | ||
| 412 | # Checks if the accelerator has performed an optimization step behind the scenes | 420 | # Checks if the accelerator has performed an optimization step behind the scenes |
| 413 | if accelerator.sync_gradients: | 421 | if accelerator.sync_gradients: |
| 414 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | 422 | callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0]) |
| 415 | 423 | ||
| 416 | local_progress_bar.update(1) | 424 | local_progress_bar.update(1) |
| 417 | global_progress_bar.update(1) | 425 | global_progress_bar.update(1) |
| @@ -425,7 +433,7 @@ def train_loop( | |||
| 425 | "train/cur_acc": acc.item(), | 433 | "train/cur_acc": acc.item(), |
| 426 | "lr": lr_scheduler.get_last_lr()[0], | 434 | "lr": lr_scheduler.get_last_lr()[0], |
| 427 | } | 435 | } |
| 428 | logs.update(on_log()) | 436 | logs.update(callbacks.on_log()) |
| 429 | 437 | ||
| 430 | accelerator.log(logs, step=global_step) | 438 | accelerator.log(logs, step=global_step) |
| 431 | 439 | ||
| @@ -441,7 +449,7 @@ def train_loop( | |||
| 441 | cur_loss_val = AverageMeter() | 449 | cur_loss_val = AverageMeter() |
| 442 | cur_acc_val = AverageMeter() | 450 | cur_acc_val = AverageMeter() |
| 443 | 451 | ||
| 444 | with torch.inference_mode(), on_eval(): | 452 | with torch.inference_mode(), callbacks.on_eval(): |
| 445 | for step, batch in enumerate(val_dataloader): | 453 | for step, batch in enumerate(val_dataloader): |
| 446 | loss, acc, bsz = loss_step(step, batch, True) | 454 | loss, acc, bsz = loss_step(step, batch, True) |
| 447 | 455 | ||
| @@ -477,20 +485,20 @@ def train_loop( | |||
| 477 | if avg_acc_val.avg.item() > max_acc_val: | 485 | if avg_acc_val.avg.item() > max_acc_val: |
| 478 | accelerator.print( | 486 | accelerator.print( |
| 479 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 487 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
| 480 | on_checkpoint(global_step + global_step_offset, "milestone") | 488 | callbacks.on_checkpoint(global_step + global_step_offset, "milestone") |
| 481 | max_acc_val = avg_acc_val.avg.item() | 489 | max_acc_val = avg_acc_val.avg.item() |
| 482 | 490 | ||
| 483 | # Create the pipeline using using the trained modules and save it. | 491 | # Create the pipeline using using the trained modules and save it. |
| 484 | if accelerator.is_main_process: | 492 | if accelerator.is_main_process: |
| 485 | print("Finished!") | 493 | print("Finished!") |
| 486 | on_checkpoint(global_step + global_step_offset, "end") | 494 | callbacks.on_checkpoint(global_step + global_step_offset, "end") |
| 487 | on_sample(global_step + global_step_offset) | 495 | callbacks.on_sample(global_step + global_step_offset) |
| 488 | accelerator.end_training() | 496 | accelerator.end_training() |
| 489 | 497 | ||
| 490 | except KeyboardInterrupt: | 498 | except KeyboardInterrupt: |
| 491 | if accelerator.is_main_process: | 499 | if accelerator.is_main_process: |
| 492 | print("Interrupted") | 500 | print("Interrupted") |
| 493 | on_checkpoint(global_step + global_step_offset, "end") | 501 | callbacks.on_checkpoint(global_step + global_step_offset, "end") |
| 494 | accelerator.end_training() | 502 | accelerator.end_training() |
| 495 | 503 | ||
| 496 | 504 | ||
| @@ -511,14 +519,7 @@ def train( | |||
| 511 | checkpoint_frequency: int = 50, | 519 | checkpoint_frequency: int = 50, |
| 512 | global_step_offset: int = 0, | 520 | global_step_offset: int = 0, |
| 513 | prior_loss_weight: float = 0, | 521 | prior_loss_weight: float = 0, |
| 514 | on_prepare: Callable[[], dict[str, Any]] = const({}), | 522 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
| 515 | on_log: Callable[[], dict[str, Any]] = const({}), | ||
| 516 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), | ||
| 517 | on_before_optimize: Callable[[int], None] = const(), | ||
| 518 | on_after_optimize: Callable[[float], None] = const(), | ||
| 519 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), | ||
| 520 | on_sample: Callable[[int], None] = const(), | ||
| 521 | on_checkpoint: Callable[[int, str], None] = const(), | ||
| 522 | ): | 523 | ): |
| 523 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 524 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
| 524 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 525 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| @@ -530,7 +531,7 @@ def train( | |||
| 530 | model.requires_grad_(False) | 531 | model.requires_grad_(False) |
| 531 | model.eval() | 532 | model.eval() |
| 532 | 533 | ||
| 533 | on_prepare() | 534 | callbacks.on_prepare() |
| 534 | 535 | ||
| 535 | loss_step_ = partial( | 536 | loss_step_ = partial( |
| 536 | loss_step, | 537 | loss_step, |
| @@ -557,13 +558,7 @@ def train( | |||
| 557 | checkpoint_frequency=checkpoint_frequency, | 558 | checkpoint_frequency=checkpoint_frequency, |
| 558 | global_step_offset=global_step_offset, | 559 | global_step_offset=global_step_offset, |
| 559 | num_epochs=num_train_epochs, | 560 | num_epochs=num_train_epochs, |
| 560 | on_log=on_log, | 561 | callbacks=callbacks, |
| 561 | on_train=on_train, | ||
| 562 | on_before_optimize=on_before_optimize, | ||
| 563 | on_after_optimize=on_after_optimize, | ||
| 564 | on_eval=on_eval, | ||
| 565 | on_sample=on_sample, | ||
| 566 | on_checkpoint=on_checkpoint, | ||
| 567 | ) | 562 | ) |
| 568 | 563 | ||
| 569 | accelerator.free_memory() | 564 | accelerator.free_memory() |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 83dc566..6f8384f 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -15,7 +15,7 @@ from slugify import slugify | |||
| 15 | 15 | ||
| 16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
| 18 | from training.functional import save_samples | 18 | from training.functional import TrainingCallbacks, save_samples |
| 19 | 19 | ||
| 20 | 20 | ||
| 21 | def textual_inversion_strategy( | 21 | def textual_inversion_strategy( |
| @@ -153,12 +153,12 @@ def textual_inversion_strategy( | |||
| 153 | with ema_context: | 153 | with ema_context: |
| 154 | save_samples_(step=step) | 154 | save_samples_(step=step) |
| 155 | 155 | ||
| 156 | return { | 156 | return TrainingCallbacks( |
| 157 | "on_prepare": on_prepare, | 157 | on_prepare=on_prepare, |
| 158 | "on_train": on_train, | 158 | on_train=on_train, |
| 159 | "on_eval": on_eval, | 159 | on_eval=on_eval, |
| 160 | "on_after_optimize": on_after_optimize, | 160 | on_after_optimize=on_after_optimize, |
| 161 | "on_log": on_log, | 161 | on_log=on_log, |
| 162 | "on_checkpoint": on_checkpoint, | 162 | on_checkpoint=on_checkpoint, |
| 163 | "on_sample": on_sample, | 163 | on_sample=on_sample, |
| 164 | } | 164 | ) |
