summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py100
1 files changed, 23 insertions, 77 deletions
diff --git a/training/functional.py b/training/functional.py
index 4ca7470..c01595a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -33,6 +33,7 @@ def const(result=None):
33@dataclass 33@dataclass
34class TrainingCallbacks(): 34class TrainingCallbacks():
35 on_prepare: Callable[[float], None] = const() 35 on_prepare: Callable[[float], None] = const()
36 on_model: Callable[[], torch.nn.Module] = const(None)
36 on_log: Callable[[], dict[str, Any]] = const({}) 37 on_log: Callable[[], dict[str, Any]] = const({})
37 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 38 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
38 on_before_optimize: Callable[[int], None] = const() 39 on_before_optimize: Callable[[int], None] = const()
@@ -267,6 +268,7 @@ def loss_step(
267 noise_scheduler: DDPMScheduler, 268 noise_scheduler: DDPMScheduler,
268 unet: UNet2DConditionModel, 269 unet: UNet2DConditionModel,
269 text_encoder: CLIPTextModel, 270 text_encoder: CLIPTextModel,
271 with_prior_preservation: bool,
270 prior_loss_weight: float, 272 prior_loss_weight: float,
271 seed: int, 273 seed: int,
272 step: int, 274 step: int,
@@ -322,7 +324,7 @@ def loss_step(
322 else: 324 else:
323 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 325 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
324 326
325 if batch["with_prior"].all(): 327 if with_prior_preservation:
326 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 328 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
327 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 329 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
328 target, target_prior = torch.chunk(target, 2, dim=0) 330 target, target_prior = torch.chunk(target, 2, dim=0)
@@ -347,7 +349,6 @@ def train_loop(
347 accelerator: Accelerator, 349 accelerator: Accelerator,
348 optimizer: torch.optim.Optimizer, 350 optimizer: torch.optim.Optimizer,
349 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 351 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
350 model: torch.nn.Module,
351 train_dataloader: DataLoader, 352 train_dataloader: DataLoader,
352 val_dataloader: DataLoader, 353 val_dataloader: DataLoader,
353 loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], 354 loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -387,28 +388,37 @@ def train_loop(
387 ) 388 )
388 global_progress_bar.set_description("Total progress") 389 global_progress_bar.set_description("Total progress")
389 390
391 model = callbacks.on_model()
392 on_log = callbacks.on_log
393 on_train = callbacks.on_train
394 on_before_optimize = callbacks.on_before_optimize
395 on_after_optimize = callbacks.on_after_optimize
396 on_eval = callbacks.on_eval
397 on_sample = callbacks.on_sample
398 on_checkpoint = callbacks.on_checkpoint
399
390 try: 400 try:
391 for epoch in range(num_epochs): 401 for epoch in range(num_epochs):
392 if accelerator.is_main_process: 402 if accelerator.is_main_process:
393 if epoch % sample_frequency == 0: 403 if epoch % sample_frequency == 0:
394 callbacks.on_sample(global_step + global_step_offset) 404 on_sample(global_step + global_step_offset)
395 405
396 if epoch % checkpoint_frequency == 0 and epoch != 0: 406 if epoch % checkpoint_frequency == 0 and epoch != 0:
397 callbacks.on_checkpoint(global_step + global_step_offset, "training") 407 on_checkpoint(global_step + global_step_offset, "training")
398 408
399 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 409 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
400 local_progress_bar.reset() 410 local_progress_bar.reset()
401 411
402 model.train() 412 model.train()
403 413
404 with callbacks.on_train(epoch): 414 with on_train(epoch):
405 for step, batch in enumerate(train_dataloader): 415 for step, batch in enumerate(train_dataloader):
406 with accelerator.accumulate(model): 416 with accelerator.accumulate(model):
407 loss, acc, bsz = loss_step(step, batch) 417 loss, acc, bsz = loss_step(step, batch)
408 418
409 accelerator.backward(loss) 419 accelerator.backward(loss)
410 420
411 callbacks.on_before_optimize(epoch) 421 on_before_optimize(epoch)
412 422
413 optimizer.step() 423 optimizer.step()
414 lr_scheduler.step() 424 lr_scheduler.step()
@@ -419,7 +429,7 @@ def train_loop(
419 429
420 # Checks if the accelerator has performed an optimization step behind the scenes 430 # Checks if the accelerator has performed an optimization step behind the scenes
421 if accelerator.sync_gradients: 431 if accelerator.sync_gradients:
422 callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0]) 432 on_after_optimize(lr_scheduler.get_last_lr()[0])
423 433
424 local_progress_bar.update(1) 434 local_progress_bar.update(1)
425 global_progress_bar.update(1) 435 global_progress_bar.update(1)
@@ -433,7 +443,7 @@ def train_loop(
433 "train/cur_acc": acc.item(), 443 "train/cur_acc": acc.item(),
434 "lr": lr_scheduler.get_last_lr()[0], 444 "lr": lr_scheduler.get_last_lr()[0],
435 } 445 }
436 logs.update(callbacks.on_log()) 446 logs.update(on_log())
437 447
438 accelerator.log(logs, step=global_step) 448 accelerator.log(logs, step=global_step)
439 449
@@ -449,7 +459,7 @@ def train_loop(
449 cur_loss_val = AverageMeter() 459 cur_loss_val = AverageMeter()
450 cur_acc_val = AverageMeter() 460 cur_acc_val = AverageMeter()
451 461
452 with torch.inference_mode(), callbacks.on_eval(): 462 with torch.inference_mode(), on_eval():
453 for step, batch in enumerate(val_dataloader): 463 for step, batch in enumerate(val_dataloader):
454 loss, acc, bsz = loss_step(step, batch, True) 464 loss, acc, bsz = loss_step(step, batch, True)
455 465
@@ -485,80 +495,16 @@ def train_loop(
485 if avg_acc_val.avg.item() > max_acc_val: 495 if avg_acc_val.avg.item() > max_acc_val:
486 accelerator.print( 496 accelerator.print(
487 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") 497 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
488 callbacks.on_checkpoint(global_step + global_step_offset, "milestone") 498 on_checkpoint(global_step + global_step_offset, "milestone")
489 max_acc_val = avg_acc_val.avg.item() 499 max_acc_val = avg_acc_val.avg.item()
490 500
491 # Create the pipeline using using the trained modules and save it. 501 # Create the pipeline using using the trained modules and save it.
492 if accelerator.is_main_process: 502 if accelerator.is_main_process:
493 print("Finished!") 503 print("Finished!")
494 callbacks.on_checkpoint(global_step + global_step_offset, "end") 504 on_checkpoint(global_step + global_step_offset, "end")
495 callbacks.on_sample(global_step + global_step_offset) 505 on_sample(global_step + global_step_offset)
496 accelerator.end_training()
497 506
498 except KeyboardInterrupt: 507 except KeyboardInterrupt:
499 if accelerator.is_main_process: 508 if accelerator.is_main_process:
500 print("Interrupted") 509 print("Interrupted")
501 callbacks.on_checkpoint(global_step + global_step_offset, "end") 510 on_checkpoint(global_step + global_step_offset, "end")
502 accelerator.end_training()
503
504
505def train(
506 accelerator: Accelerator,
507 unet: UNet2DConditionModel,
508 text_encoder: CLIPTextModel,
509 vae: AutoencoderKL,
510 noise_scheduler: DDPMScheduler,
511 train_dataloader: DataLoader,
512 val_dataloader: DataLoader,
513 dtype: torch.dtype,
514 seed: int,
515 optimizer: torch.optim.Optimizer,
516 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
517 num_train_epochs: int = 100,
518 sample_frequency: int = 20,
519 checkpoint_frequency: int = 50,
520 global_step_offset: int = 0,
521 prior_loss_weight: float = 0,
522 callbacks: TrainingCallbacks = TrainingCallbacks(),
523):
524 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
525 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
526 )
527
528 vae.to(accelerator.device, dtype=dtype)
529
530 for model in (unet, text_encoder, vae):
531 model.requires_grad_(False)
532 model.eval()
533
534 callbacks.on_prepare()
535
536 loss_step_ = partial(
537 loss_step,
538 vae,
539 noise_scheduler,
540 unet,
541 text_encoder,
542 prior_loss_weight,
543 seed,
544 )
545
546 if accelerator.is_main_process:
547 accelerator.init_trackers("textual_inversion")
548
549 train_loop(
550 accelerator=accelerator,
551 optimizer=optimizer,
552 lr_scheduler=lr_scheduler,
553 model=text_encoder,
554 train_dataloader=train_dataloader,
555 val_dataloader=val_dataloader,
556 loss_step=loss_step_,
557 sample_frequency=sample_frequency,
558 checkpoint_frequency=checkpoint_frequency,
559 global_step_offset=global_step_offset,
560 num_epochs=num_train_epochs,
561 callbacks=callbacks,
562 )
563
564 accelerator.free_memory()