diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 101 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
-rw-r--r-- | training/strategy/ti.py | 2 |
3 files changed, 58 insertions, 47 deletions
diff --git a/training/functional.py b/training/functional.py index 1b6162b..c6b4dc3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -73,7 +73,7 @@ def save_samples( | |||
73 | vae: AutoencoderKL, | 73 | vae: AutoencoderKL, |
74 | sample_scheduler: DPMSolverMultistepScheduler, | 74 | sample_scheduler: DPMSolverMultistepScheduler, |
75 | train_dataloader: DataLoader, | 75 | train_dataloader: DataLoader, |
76 | val_dataloader: DataLoader, | 76 | val_dataloader: Optional[DataLoader], |
77 | dtype: torch.dtype, | 77 | dtype: torch.dtype, |
78 | output_dir: Path, | 78 | output_dir: Path, |
79 | seed: int, | 79 | seed: int, |
@@ -111,11 +111,13 @@ def save_samples( | |||
111 | 111 | ||
112 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) | 112 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) |
113 | 113 | ||
114 | for pool, data, gen in [ | 114 | datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] |
115 | ("stable", val_dataloader, generator), | 115 | |
116 | ("val", val_dataloader, None), | 116 | if val_dataloader is not None: |
117 | ("train", train_dataloader, None) | 117 | datasets.append(("stable", val_dataloader, generator)) |
118 | ]: | 118 | datasets.append(("val", val_dataloader, None)) |
119 | |||
120 | for pool, data, gen in datasets: | ||
119 | all_samples = [] | 121 | all_samples = [] |
120 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 122 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
121 | file_path.parent.mkdir(parents=True, exist_ok=True) | 123 | file_path.parent.mkdir(parents=True, exist_ok=True) |
@@ -328,7 +330,7 @@ def train_loop( | |||
328 | optimizer: torch.optim.Optimizer, | 330 | optimizer: torch.optim.Optimizer, |
329 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 331 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
330 | train_dataloader: DataLoader, | 332 | train_dataloader: DataLoader, |
331 | val_dataloader: DataLoader, | 333 | val_dataloader: Optional[DataLoader], |
332 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 334 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
333 | sample_frequency: int = 10, | 335 | sample_frequency: int = 10, |
334 | checkpoint_frequency: int = 50, | 336 | checkpoint_frequency: int = 50, |
@@ -337,7 +339,7 @@ def train_loop( | |||
337 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 339 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
338 | ): | 340 | ): |
339 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | 341 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) |
340 | num_val_steps_per_epoch = len(val_dataloader) | 342 | num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0 |
341 | 343 | ||
342 | num_training_steps = num_training_steps_per_epoch * num_epochs | 344 | num_training_steps = num_training_steps_per_epoch * num_epochs |
343 | num_val_steps = num_val_steps_per_epoch * num_epochs | 345 | num_val_steps = num_val_steps_per_epoch * num_epochs |
@@ -350,6 +352,7 @@ def train_loop( | |||
350 | avg_loss_val = AverageMeter() | 352 | avg_loss_val = AverageMeter() |
351 | avg_acc_val = AverageMeter() | 353 | avg_acc_val = AverageMeter() |
352 | 354 | ||
355 | max_acc = 0.0 | ||
353 | max_acc_val = 0.0 | 356 | max_acc_val = 0.0 |
354 | 357 | ||
355 | local_progress_bar = tqdm( | 358 | local_progress_bar = tqdm( |
@@ -432,49 +435,57 @@ def train_loop( | |||
432 | 435 | ||
433 | accelerator.wait_for_everyone() | 436 | accelerator.wait_for_everyone() |
434 | 437 | ||
435 | model.eval() | 438 | if val_dataloader is not None: |
436 | 439 | model.eval() | |
437 | cur_loss_val = AverageMeter() | ||
438 | cur_acc_val = AverageMeter() | ||
439 | |||
440 | with torch.inference_mode(), on_eval(): | ||
441 | for step, batch in enumerate(val_dataloader): | ||
442 | loss, acc, bsz = loss_step(step, batch, True) | ||
443 | |||
444 | loss = loss.detach_() | ||
445 | acc = acc.detach_() | ||
446 | |||
447 | cur_loss_val.update(loss, bsz) | ||
448 | cur_acc_val.update(acc, bsz) | ||
449 | 440 | ||
450 | avg_loss_val.update(loss, bsz) | 441 | cur_loss_val = AverageMeter() |
451 | avg_acc_val.update(acc, bsz) | 442 | cur_acc_val = AverageMeter() |
452 | 443 | ||
453 | local_progress_bar.update(1) | 444 | with torch.inference_mode(), on_eval(): |
454 | global_progress_bar.update(1) | 445 | for step, batch in enumerate(val_dataloader): |
446 | loss, acc, bsz = loss_step(step, batch, True) | ||
455 | 447 | ||
456 | logs = { | 448 | loss = loss.detach_() |
457 | "val/loss": avg_loss_val.avg.item(), | 449 | acc = acc.detach_() |
458 | "val/acc": avg_acc_val.avg.item(), | ||
459 | "val/cur_loss": loss.item(), | ||
460 | "val/cur_acc": acc.item(), | ||
461 | } | ||
462 | local_progress_bar.set_postfix(**logs) | ||
463 | 450 | ||
464 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 451 | cur_loss_val.update(loss, bsz) |
465 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 452 | cur_acc_val.update(acc, bsz) |
466 | 453 | ||
467 | accelerator.log(logs, step=global_step) | 454 | avg_loss_val.update(loss, bsz) |
455 | avg_acc_val.update(acc, bsz) | ||
468 | 456 | ||
469 | local_progress_bar.clear() | 457 | local_progress_bar.update(1) |
470 | global_progress_bar.clear() | 458 | global_progress_bar.update(1) |
471 | 459 | ||
472 | if accelerator.is_main_process: | 460 | logs = { |
473 | if avg_acc_val.avg.item() > max_acc_val: | 461 | "val/loss": avg_loss_val.avg.item(), |
474 | accelerator.print( | 462 | "val/acc": avg_acc_val.avg.item(), |
475 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 463 | "val/cur_loss": loss.item(), |
476 | on_checkpoint(global_step + global_step_offset, "milestone") | 464 | "val/cur_acc": acc.item(), |
477 | max_acc_val = avg_acc_val.avg.item() | 465 | } |
466 | local_progress_bar.set_postfix(**logs) | ||
467 | |||
468 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
469 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
470 | |||
471 | accelerator.log(logs, step=global_step) | ||
472 | |||
473 | local_progress_bar.clear() | ||
474 | global_progress_bar.clear() | ||
475 | |||
476 | if accelerator.is_main_process: | ||
477 | if avg_acc_val.avg.item() > max_acc_val: | ||
478 | accelerator.print( | ||
479 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
480 | on_checkpoint(global_step + global_step_offset, "milestone") | ||
481 | max_acc_val = avg_acc_val.avg.item() | ||
482 | else: | ||
483 | if accelerator.is_main_process: | ||
484 | if avg_acc.avg.item() > max_acc: | ||
485 | accelerator.print( | ||
486 | f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") | ||
487 | on_checkpoint(global_step + global_step_offset, "milestone") | ||
488 | max_acc = avg_acc.avg.item() | ||
478 | 489 | ||
479 | # Create the pipeline using the trained modules and save it. | 490 | # Create the pipeline using the trained modules and save it. |
480 | if accelerator.is_main_process: | 491 | if accelerator.is_main_process: |
@@ -499,7 +510,7 @@ def train( | |||
499 | seed: int, | 510 | seed: int, |
500 | project: str, | 511 | project: str, |
501 | train_dataloader: DataLoader, | 512 | train_dataloader: DataLoader, |
502 | val_dataloader: DataLoader, | 513 | val_dataloader: Optional[DataLoader], |
503 | optimizer: torch.optim.Optimizer, | 514 | optimizer: torch.optim.Optimizer, |
504 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 515 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
505 | callbacks_fn: Callable[..., TrainingCallbacks], | 516 | callbacks_fn: Callable[..., TrainingCallbacks], |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 6e7ebe2..aeaa828 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -26,7 +26,7 @@ def dreambooth_strategy( | |||
26 | vae: AutoencoderKL, | 26 | vae: AutoencoderKL, |
27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
29 | val_dataloader: DataLoader, | 29 | val_dataloader: Optional[DataLoader], |
30 | output_dir: Path, | 30 | output_dir: Path, |
31 | seed: int, | 31 | seed: int, |
32 | train_text_encoder_epochs: int, | 32 | train_text_encoder_epochs: int, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 753dce0..568f9eb 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -26,7 +26,7 @@ def textual_inversion_strategy( | |||
26 | vae: AutoencoderKL, | 26 | vae: AutoencoderKL, |
27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
29 | val_dataloader: DataLoader, | 29 | val_dataloader: Optional[DataLoader], |
30 | output_dir: Path, | 30 | output_dir: Path, |
31 | seed: int, | 31 | seed: int, |
32 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], |