From 3c6ccadd3c12c54a1fa2280bce505a2dd511958a Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 07:27:45 +0100 Subject: Implemented extended Dreambooth training --- training/functional.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 5984ffb..f5c111e 100644 --- a/training/functional.py +++ b/training/functional.py @@ -494,10 +494,11 @@ def train( text_encoder: CLIPTextModel, vae: AutoencoderKL, noise_scheduler: DDPMScheduler, - train_dataloader: DataLoader, - val_dataloader: DataLoader, dtype: torch.dtype, seed: int, + project: str, + train_dataloader: DataLoader, + val_dataloader: DataLoader, optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, callbacks_fn: Callable[..., TrainingCallbacks], @@ -544,7 +545,7 @@ def train( ) if accelerator.is_main_process: - accelerator.init_trackers("textual_inversion") + accelerator.init_trackers(project) train_loop( accelerator=accelerator, -- cgit v1.2.3-70-g09d2