Diffstat (limited to 'training/functional.py')
 training/functional.py | 31 +++++++++++--------------------
 1 file changed, 11 insertions(+), 20 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index c373ac9..8f47734 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[], None] = const()
-    on_model: Callable[[], torch.nn.Module] = const(None)
+    on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[float, int], None] = const()
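The rename makes the callback's role explicit: it returns the module used as the gradient-accumulation target in train_loop (see the hunk at -393 below). A minimal sketch of the wiring, assuming `TrainingCallbacks` is importable from this module; the `unet` stand-in is hypothetical:

```python
import torch

from training.functional import TrainingCallbacks  # assumes this module layout

# Stand-in module; a real strategy would return its UNet or text encoder here.
unet = torch.nn.Linear(4, 4)

callbacks = TrainingCallbacks(
    on_accum_model=lambda: unet,  # the module gradient accumulation syncs on
)

model = callbacks.on_accum_model()  # mirrors the renamed call in train_loop
```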
@@ -51,7 +51,11 @@ class TrainingStrategyPrepareCallable(Protocol):
         accelerator: Accelerator,
         text_encoder: CLIPTextModel,
         unet: UNet2DConditionModel,
-        *args
+        optimizer: torch.optim.Optimizer,
+        train_dataloader: DataLoader,
+        val_dataloader: Optional[DataLoader],
+        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+        **kwargs
     ) -> Tuple: ...
 
 
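The Protocol now spells out the full `prepare` signature instead of swallowing everything through `*args`. A sketch of a conforming implementation, assuming a pass-through strategy built on `accelerator.prepare`; the body is illustrative, not this repo's actual strategy code:

```python
from typing import Optional, Tuple

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import CLIPTextModel
from diffusers import UNet2DConditionModel


def prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    **kwargs,
) -> Tuple:
    # Hand everything to Accelerate; it returns the wrapped objects in order.
    text_encoder, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, unet, optimizer, train_dataloader, lr_scheduler
    )
    # The trailing dict is the `extra` payload that train() merges into kwargs
    # (see the hunk at -559 below).
    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
```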
@@ -92,7 +96,6 @@ def save_samples(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     step: int,
@@ -107,15 +110,6 @@ def save_samples(
     grid_cols = min(batch_size, 4)
     grid_rows = (num_batches * batch_size) // grid_cols
 
-    unet = accelerator.unwrap_model(unet)
-    text_encoder = accelerator.unwrap_model(text_encoder)
-
-    orig_unet_dtype = unet.dtype
-    orig_text_encoder_dtype = text_encoder.dtype
-
-    unet.to(dtype=dtype)
-    text_encoder.to(dtype=dtype)
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
@@ -172,11 +166,6 @@ def save_samples(
     image_grid = make_grid(all_samples, grid_rows, grid_cols)
     image_grid.save(file_path, quality=85)
 
-    unet.to(dtype=orig_unet_dtype)
-    text_encoder.to(dtype=orig_text_encoder_dtype)
-
-    del unet
-    del text_encoder
     del generator
     del pipeline
 
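The two deleted blocks above were the only reason save_samples needed its `dtype` parameter: unwrap the models, remember each one's dtype, cast for sampling, then cast back. If a call site still wants that behaviour, the same dance can be captured once; a sketch, assuming the models expose `.dtype` the way diffusers/transformers models do (plain `torch.nn.Module` does not):

```python
from contextlib import contextmanager

import torch


@contextmanager
def cast_dtype(model, dtype: torch.dtype):
    # Reproduces the save/cast/restore pattern this commit removes
    # from save_samples; illustrative only.
    orig_dtype = model.dtype
    model.to(dtype=dtype)
    try:
        yield model
    finally:
        model.to(dtype=orig_dtype)
```

Used as `with cast_dtype(unet, torch.float16): ...`, this restores the original dtype even if sampling raises.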
@@ -393,7 +382,7 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")
 
-    model = callbacks.on_model()
+    model = callbacks.on_accum_model()
     on_log = callbacks.on_log
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
@@ -559,8 +548,10 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
+        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+
+    kwargs.update(extra)
 
     vae.to(accelerator.device, dtype=dtype)
 
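The seventh return value is the new part of the contract: `prepare` can hand strategy-specific values back to train(), which folds them into `kwargs` before the loop starts. A hedged sketch of the producing side; the key is invented for illustration:

```python
def prepare(accelerator, text_encoder, unet, optimizer,
            train_dataloader, val_dataloader, lr_scheduler, **kwargs):
    # ... wrap objects with accelerator.prepare() as in the sketch above ...
    extra = {"example_option": 42}  # invented key, purely illustrative
    return (text_encoder, unet, optimizer, train_dataloader,
            val_dataloader, lr_scheduler, extra)
```

After `kwargs.update(extra)`, the value is visible to everything downstream in train() that receives `**kwargs`.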
