Diffstat (limited to 'training')
-rw-r--r--  training/functional.py   | 63
-rw-r--r--  training/strategy/ti.py  | 20
2 files changed, 39 insertions, 44 deletions
diff --git a/training/functional.py b/training/functional.py
index e54c9c8..4ca7470 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union, Optional
@@ -14,6 +15,7 @@ from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
 
 from tqdm.auto import tqdm
+from PIL import Image
 
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
@@ -28,6 +30,18 @@ def const(result=None):
     return fn
 
 
+@dataclass
+class TrainingCallbacks():
+    on_prepare: Callable[[float], None] = const()
+    on_log: Callable[[], dict[str, Any]] = const({})
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
+    on_before_optimize: Callable[[int], None] = const()
+    on_after_optimize: Callable[[float], None] = const()
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
+    on_sample: Callable[[int], None] = const()
+    on_checkpoint: Callable[[int, str], None] = const()
+
+
 def make_grid(images, rows, cols):
     w, h = images[0].size
     grid = Image.new('RGB', size=(cols*w, rows*h))
@@ -341,13 +355,7 @@ def train_loop(
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
-    on_log: Callable[[], dict[str, Any]] = const({}),
-    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
-    on_before_optimize: Callable[[int], None] = const(),
-    on_after_optimize: Callable[[float], None] = const(),
-    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
-    on_sample: Callable[[int], None] = const(),
-    on_checkpoint: Callable[[int, str], None] = const(),
+    callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
@@ -383,24 +391,24 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    on_sample(global_step + global_step_offset)
+                    callbacks.on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    on_checkpoint(global_step + global_step_offset, "training")
+                    callbacks.on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
             model.train()
 
-            with on_train(epoch):
+            with callbacks.on_train(epoch):
                 for step, batch in enumerate(train_dataloader):
                     with accelerator.accumulate(model):
                         loss, acc, bsz = loss_step(step, batch)
 
                         accelerator.backward(loss)
 
-                        on_before_optimize(epoch)
+                        callbacks.on_before_optimize(epoch)
 
                         optimizer.step()
                         lr_scheduler.step()
@@ -411,7 +419,7 @@ def train_loop(
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
-                        on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
@@ -425,7 +433,7 @@ def train_loop(
                         "train/cur_acc": acc.item(),
                         "lr": lr_scheduler.get_last_lr()[0],
                     }
-                    logs.update(on_log())
+                    logs.update(callbacks.on_log())
 
                     accelerator.log(logs, step=global_step)
 
@@ -441,7 +449,7 @@ def train_loop(
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()
 
-            with torch.inference_mode(), on_eval():
+            with torch.inference_mode(), callbacks.on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, True)
 
@@ -477,20 +485,20 @@ def train_loop(
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    callbacks.on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            on_checkpoint(global_step + global_step_offset, "end")
-            on_sample(global_step + global_step_offset)
+            callbacks.on_checkpoint(global_step + global_step_offset, "end")
+            callbacks.on_sample(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            on_checkpoint(global_step + global_step_offset, "end")
+            callbacks.on_checkpoint(global_step + global_step_offset, "end")
         accelerator.end_training()
 
 
@@ -511,14 +519,7 @@ def train(
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     prior_loss_weight: float = 0,
-    on_prepare: Callable[[], dict[str, Any]] = const({}),
-    on_log: Callable[[], dict[str, Any]] = const({}),
-    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
-    on_before_optimize: Callable[[int], None] = const(),
-    on_after_optimize: Callable[[float], None] = const(),
-    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
-    on_sample: Callable[[int], None] = const(),
-    on_checkpoint: Callable[[int, str], None] = const(),
+    callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
@@ -530,7 +531,7 @@ def train(
         model.requires_grad_(False)
         model.eval()
 
-    on_prepare()
+    callbacks.on_prepare()
 
     loss_step_ = partial(
         loss_step,
@@ -557,13 +558,7 @@ def train(
         checkpoint_frequency=checkpoint_frequency,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
-        on_log=on_log,
-        on_train=on_train,
-        on_before_optimize=on_before_optimize,
-        on_after_optimize=on_after_optimize,
-        on_eval=on_eval,
-        on_sample=on_sample,
-        on_checkpoint=on_checkpoint,
+        callbacks=callbacks,
     )
 
     accelerator.free_memory()
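
In functional.py, the seven loose on_* keyword arguments of train() and train_loop() collapse into a single callbacks parameter, and every hook is invoked through the TrainingCallbacks dataclass. A minimal sketch of what a call site might look like after this change; the hook bodies and the elided train() arguments are illustrative, not taken from this commit:

    from training.functional import TrainingCallbacks, train

    # Set only the hooks you need; unset fields keep the const()/nullcontext()
    # defaults declared on the dataclass, so train_loop() can call them unconditionally.
    callbacks = TrainingCallbacks(
        on_log=lambda: {"ema_decay": 0.999},  # extra values merged into accelerator.log()
        on_checkpoint=lambda step, postfix: print(f"checkpoint @ {step} ({postfix})"),
    )

    # The remaining train() arguments (accelerator, models, data loaders, optimizer,
    # scheduler, ...) are unchanged; only the callback wiring moves:
    # train(..., callbacks=callbacks)
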
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 83dc566..6f8384f 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,7 +15,7 @@ from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import save_samples
+from training.functional import TrainingCallbacks, save_samples
 
 
 def textual_inversion_strategy(
@@ -153,12 +153,12 @@ def textual_inversion_strategy(
         with ema_context:
             save_samples_(step=step)
 
-    return {
-        "on_prepare": on_prepare,
-        "on_train": on_train,
-        "on_eval": on_eval,
-        "on_after_optimize": on_after_optimize,
-        "on_log": on_log,
-        "on_checkpoint": on_checkpoint,
-        "on_sample": on_sample,
-    }
+    return TrainingCallbacks(
+        on_prepare=on_prepare,
+        on_train=on_train,
+        on_eval=on_eval,
+        on_after_optimize=on_after_optimize,
+        on_log=on_log,
+        on_checkpoint=on_checkpoint,
+        on_sample=on_sample,
+    )
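
training/strategy/ti.py follows suit: textual_inversion_strategy() now returns a TrainingCallbacks instance instead of a plain dict, so its result can be handed to train() unchanged. Hooks the strategy does not define (for example on_before_optimize) keep the no-op defaults of the dataclass. A rough illustration, with the strategy's own arguments elided because they are not part of this diff:

    callbacks = textual_inversion_strategy(...)  # now a TrainingCallbacks, not a dict
    callbacks.on_before_optimize(0)              # not set by this strategy -> harmless const() no-op
    # train(..., callbacks=callbacks)
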