Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  75
1 file changed, 74 insertions(+), 1 deletion(-)
diff --git a/training/functional.py b/training/functional.py
index c5b514a..1f2ca6d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,6 +1,7 @@
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union
+from typing import Callable, Any, Tuple, Union, Optional
+from functools import partial
 
 import torch
 import torch.nn.functional as F
@@ -376,3 +377,75 @@ def train_loop(
         print("Interrupted")
         on_checkpoint(global_step + global_step_offset, "end")
         accelerator.end_training()
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    prior_loss_weight: float = 0,
+    on_prepare: Callable[[], dict[str, Any]] = const({}),
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        prior_loss_weight,
+        seed,
+    )
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        model=text_encoder,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        on_log=on_log,
+        on_train=on_train,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_eval=on_eval,
+        on_sample=on_sample,
+        on_checkpoint=on_checkpoint,
+    )
+
+    accelerator.free_memory()
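The callback defaults above lean on a const helper that sits outside this diff's context. A minimal sketch of what it presumably looks like, inferred only from the call sites (const({}) as a callback returning a dict, const(nullcontext()) as a context-manager factory, bare const() as a no-op):

def const(result=None):
    # Assumed behavior, not shown in this commit: wrap a fixed value in a
    # callback that ignores whatever arguments the training loop passes it.
    def fn(*args, **kwargs):
        return result
    return fn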
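For orientation, a hypothetical call site for the new train() entry point. The checkpoint name, optimizer setup, and dataloaders below are illustrative assumptions, not part of this commit; callers are expected to supply their own along with the on_* callbacks:

import torch
from accelerate import Accelerator
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel

from training.functional import train

accelerator = Accelerator()

# Illustrative model loading; the checkpoint id is an assumption, not from this commit.
model_id = "runwayml/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

# Placeholder data pipeline: this commit leaves dataloader construction to the caller.
train_dataloader = ...
val_dataloader = ...

optimizer = torch.optim.AdamW(text_encoder.get_input_embeddings().parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)

train(
    accelerator=accelerator,
    unet=unet,
    text_encoder=text_encoder,
    vae=vae,
    noise_scheduler=noise_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    dtype=torch.float32,
    seed=42,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    on_checkpoint=lambda step, reason: print(f"checkpoint @ {step} ({reason})"),
)

As the signature suggests, everything method-specific is meant to arrive through the keyword arguments and callbacks, while the loop mechanics stay in the shared train_loop.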