diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index 5984ffb..f5c111e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -494,10 +494,11 @@ def train( | |||
494 | text_encoder: CLIPTextModel, | 494 | text_encoder: CLIPTextModel, |
495 | vae: AutoencoderKL, | 495 | vae: AutoencoderKL, |
496 | noise_scheduler: DDPMScheduler, | 496 | noise_scheduler: DDPMScheduler, |
497 | train_dataloader: DataLoader, | ||
498 | val_dataloader: DataLoader, | ||
499 | dtype: torch.dtype, | 497 | dtype: torch.dtype, |
500 | seed: int, | 498 | seed: int, |
499 | project: str, | ||
500 | train_dataloader: DataLoader, | ||
501 | val_dataloader: DataLoader, | ||
501 | optimizer: torch.optim.Optimizer, | 502 | optimizer: torch.optim.Optimizer, |
502 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 503 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
503 | callbacks_fn: Callable[..., TrainingCallbacks], | 504 | callbacks_fn: Callable[..., TrainingCallbacks], |
@@ -544,7 +545,7 @@ def train( | |||
544 | ) | 545 | ) |
545 | 546 | ||
546 | if accelerator.is_main_process: | 547 | if accelerator.is_main_process: |
547 | accelerator.init_trackers("textual_inversion") | 548 | accelerator.init_trackers(project) |
548 | 549 | ||
549 | train_loop( | 550 | train_loop( |
550 | accelerator=accelerator, | 551 | accelerator=accelerator, |