summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-16 07:27:45 +0100
committerVolpeon <git@volpeon.ink>2023-01-16 07:27:45 +0100
commit3c6ccadd3c12c54a1fa2280bce505a2dd511958a (patch)
tree019b9ac09acc85196ef1d09e2d968ba917ac8993 /training/functional.py
parentAdded Dreambooth strategy (diff)
downloadtextual-inversion-diff-3c6ccadd3c12c54a1fa2280bce505a2dd511958a.tar.gz
textual-inversion-diff-3c6ccadd3c12c54a1fa2280bce505a2dd511958a.tar.bz2
textual-inversion-diff-3c6ccadd3c12c54a1fa2280bce505a2dd511958a.zip
Implemented extended Dreambooth training
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index 5984ffb..f5c111e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -494,10 +494,11 @@ def train(
494 text_encoder: CLIPTextModel, 494 text_encoder: CLIPTextModel,
495 vae: AutoencoderKL, 495 vae: AutoencoderKL,
496 noise_scheduler: DDPMScheduler, 496 noise_scheduler: DDPMScheduler,
497 train_dataloader: DataLoader,
498 val_dataloader: DataLoader,
499 dtype: torch.dtype, 497 dtype: torch.dtype,
500 seed: int, 498 seed: int,
499 project: str,
500 train_dataloader: DataLoader,
501 val_dataloader: DataLoader,
501 optimizer: torch.optim.Optimizer, 502 optimizer: torch.optim.Optimizer,
502 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 503 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
503 callbacks_fn: Callable[..., TrainingCallbacks], 504 callbacks_fn: Callable[..., TrainingCallbacks],
@@ -544,7 +545,7 @@ def train(
544 ) 545 )
545 546
546 if accelerator.is_main_process: 547 if accelerator.is_main_process:
547 accelerator.init_trackers("textual_inversion") 548 accelerator.init_trackers(project)
548 549
549 train_loop( 550 train_loop(
550 accelerator=accelerator, 551 accelerator=accelerator,