diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 93 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
| -rw-r--r-- | training/strategy/ti.py | 2 |
3 files changed, 54 insertions, 43 deletions
diff --git a/training/functional.py b/training/functional.py index 1b6162b..c6b4dc3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -73,7 +73,7 @@ def save_samples( | |||
| 73 | vae: AutoencoderKL, | 73 | vae: AutoencoderKL, |
| 74 | sample_scheduler: DPMSolverMultistepScheduler, | 74 | sample_scheduler: DPMSolverMultistepScheduler, |
| 75 | train_dataloader: DataLoader, | 75 | train_dataloader: DataLoader, |
| 76 | val_dataloader: DataLoader, | 76 | val_dataloader: Optional[DataLoader], |
| 77 | dtype: torch.dtype, | 77 | dtype: torch.dtype, |
| 78 | output_dir: Path, | 78 | output_dir: Path, |
| 79 | seed: int, | 79 | seed: int, |
| @@ -111,11 +111,13 @@ def save_samples( | |||
| 111 | 111 | ||
| 112 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) | 112 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) |
| 113 | 113 | ||
| 114 | for pool, data, gen in [ | 114 | datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] |
| 115 | ("stable", val_dataloader, generator), | 115 | |
| 116 | ("val", val_dataloader, None), | 116 | if val_dataloader is not None: |
| 117 | ("train", train_dataloader, None) | 117 | datasets.append(("stable", val_dataloader, generator)) |
| 118 | ]: | 118 | datasets.append(("val", val_dataloader, None)) |
| 119 | |||
| 120 | for pool, data, gen in datasets: | ||
| 119 | all_samples = [] | 121 | all_samples = [] |
| 120 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 122 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
| 121 | file_path.parent.mkdir(parents=True, exist_ok=True) | 123 | file_path.parent.mkdir(parents=True, exist_ok=True) |
| @@ -328,7 +330,7 @@ def train_loop( | |||
| 328 | optimizer: torch.optim.Optimizer, | 330 | optimizer: torch.optim.Optimizer, |
| 329 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 331 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 330 | train_dataloader: DataLoader, | 332 | train_dataloader: DataLoader, |
| 331 | val_dataloader: DataLoader, | 333 | val_dataloader: Optional[DataLoader], |
| 332 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 334 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
| 333 | sample_frequency: int = 10, | 335 | sample_frequency: int = 10, |
| 334 | checkpoint_frequency: int = 50, | 336 | checkpoint_frequency: int = 50, |
| @@ -337,7 +339,7 @@ def train_loop( | |||
| 337 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 339 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
| 338 | ): | 340 | ): |
| 339 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | 341 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) |
| 340 | num_val_steps_per_epoch = len(val_dataloader) | 342 | num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0 |
| 341 | 343 | ||
| 342 | num_training_steps = num_training_steps_per_epoch * num_epochs | 344 | num_training_steps = num_training_steps_per_epoch * num_epochs |
| 343 | num_val_steps = num_val_steps_per_epoch * num_epochs | 345 | num_val_steps = num_val_steps_per_epoch * num_epochs |
| @@ -350,6 +352,7 @@ def train_loop( | |||
| 350 | avg_loss_val = AverageMeter() | 352 | avg_loss_val = AverageMeter() |
| 351 | avg_acc_val = AverageMeter() | 353 | avg_acc_val = AverageMeter() |
| 352 | 354 | ||
| 355 | max_acc = 0.0 | ||
| 353 | max_acc_val = 0.0 | 356 | max_acc_val = 0.0 |
| 354 | 357 | ||
| 355 | local_progress_bar = tqdm( | 358 | local_progress_bar = tqdm( |
| @@ -432,49 +435,57 @@ def train_loop( | |||
| 432 | 435 | ||
| 433 | accelerator.wait_for_everyone() | 436 | accelerator.wait_for_everyone() |
| 434 | 437 | ||
| 435 | model.eval() | 438 | if val_dataloader is not None: |
| 439 | model.eval() | ||
| 436 | 440 | ||
| 437 | cur_loss_val = AverageMeter() | 441 | cur_loss_val = AverageMeter() |
| 438 | cur_acc_val = AverageMeter() | 442 | cur_acc_val = AverageMeter() |
| 439 | 443 | ||
| 440 | with torch.inference_mode(), on_eval(): | 444 | with torch.inference_mode(), on_eval(): |
| 441 | for step, batch in enumerate(val_dataloader): | 445 | for step, batch in enumerate(val_dataloader): |
| 442 | loss, acc, bsz = loss_step(step, batch, True) | 446 | loss, acc, bsz = loss_step(step, batch, True) |
| 443 | 447 | ||
| 444 | loss = loss.detach_() | 448 | loss = loss.detach_() |
| 445 | acc = acc.detach_() | 449 | acc = acc.detach_() |
| 446 | 450 | ||
| 447 | cur_loss_val.update(loss, bsz) | 451 | cur_loss_val.update(loss, bsz) |
| 448 | cur_acc_val.update(acc, bsz) | 452 | cur_acc_val.update(acc, bsz) |
| 449 | 453 | ||
| 450 | avg_loss_val.update(loss, bsz) | 454 | avg_loss_val.update(loss, bsz) |
| 451 | avg_acc_val.update(acc, bsz) | 455 | avg_acc_val.update(acc, bsz) |
| 452 | 456 | ||
| 453 | local_progress_bar.update(1) | 457 | local_progress_bar.update(1) |
| 454 | global_progress_bar.update(1) | 458 | global_progress_bar.update(1) |
| 455 | 459 | ||
| 456 | logs = { | 460 | logs = { |
| 457 | "val/loss": avg_loss_val.avg.item(), | 461 | "val/loss": avg_loss_val.avg.item(), |
| 458 | "val/acc": avg_acc_val.avg.item(), | 462 | "val/acc": avg_acc_val.avg.item(), |
| 459 | "val/cur_loss": loss.item(), | 463 | "val/cur_loss": loss.item(), |
| 460 | "val/cur_acc": acc.item(), | 464 | "val/cur_acc": acc.item(), |
| 461 | } | 465 | } |
| 462 | local_progress_bar.set_postfix(**logs) | 466 | local_progress_bar.set_postfix(**logs) |
| 463 | 467 | ||
| 464 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 468 | logs["val/cur_loss"] = cur_loss_val.avg.item() |
| 465 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 469 | logs["val/cur_acc"] = cur_acc_val.avg.item() |
| 466 | 470 | ||
| 467 | accelerator.log(logs, step=global_step) | 471 | accelerator.log(logs, step=global_step) |
| 468 | 472 | ||
| 469 | local_progress_bar.clear() | 473 | local_progress_bar.clear() |
| 470 | global_progress_bar.clear() | 474 | global_progress_bar.clear() |
| 471 | 475 | ||
| 472 | if accelerator.is_main_process: | 476 | if accelerator.is_main_process: |
| 473 | if avg_acc_val.avg.item() > max_acc_val: | 477 | if avg_acc_val.avg.item() > max_acc_val: |
| 474 | accelerator.print( | 478 | accelerator.print( |
| 475 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 479 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
| 476 | on_checkpoint(global_step + global_step_offset, "milestone") | 480 | on_checkpoint(global_step + global_step_offset, "milestone") |
| 477 | max_acc_val = avg_acc_val.avg.item() | 481 | max_acc_val = avg_acc_val.avg.item() |
| 482 | else: | ||
| 483 | if accelerator.is_main_process: | ||
| 484 | if avg_acc.avg.item() > max_acc: | ||
| 485 | accelerator.print( | ||
| 486 | f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") | ||
| 487 | on_checkpoint(global_step + global_step_offset, "milestone") | ||
| 488 | max_acc = avg_acc.avg.item() | ||
| 478 | 489 | ||
| 479 | # Create the pipeline using the trained modules and save it. | 490 | # Create the pipeline using the trained modules and save it. |
| 480 | if accelerator.is_main_process: | 491 | if accelerator.is_main_process: |
| @@ -499,7 +510,7 @@ def train( | |||
| 499 | seed: int, | 510 | seed: int, |
| 500 | project: str, | 511 | project: str, |
| 501 | train_dataloader: DataLoader, | 512 | train_dataloader: DataLoader, |
| 502 | val_dataloader: DataLoader, | 513 | val_dataloader: Optional[DataLoader], |
| 503 | optimizer: torch.optim.Optimizer, | 514 | optimizer: torch.optim.Optimizer, |
| 504 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 515 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 505 | callbacks_fn: Callable[..., TrainingCallbacks], | 516 | callbacks_fn: Callable[..., TrainingCallbacks], |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 6e7ebe2..aeaa828 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -26,7 +26,7 @@ def dreambooth_strategy( | |||
| 26 | vae: AutoencoderKL, | 26 | vae: AutoencoderKL, |
| 27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
| 28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
| 29 | val_dataloader: DataLoader, | 29 | val_dataloader: Optional[DataLoader], |
| 30 | output_dir: Path, | 30 | output_dir: Path, |
| 31 | seed: int, | 31 | seed: int, |
| 32 | train_text_encoder_epochs: int, | 32 | train_text_encoder_epochs: int, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 753dce0..568f9eb 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -26,7 +26,7 @@ def textual_inversion_strategy( | |||
| 26 | vae: AutoencoderKL, | 26 | vae: AutoencoderKL, |
| 27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
| 28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
| 29 | val_dataloader: DataLoader, | 29 | val_dataloader: Optional[DataLoader], |
| 30 | output_dir: Path, | 30 | output_dir: Path, |
| 31 | seed: int, | 31 | seed: int, |
| 32 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], |
