diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 31 |
1 file changed, 11 insertions, 20 deletions
diff --git a/training/functional.py b/training/functional.py index c373ac9..8f47734 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -34,7 +34,7 @@ def const(result=None): | |||
34 | @dataclass | 34 | @dataclass |
35 | class TrainingCallbacks(): | 35 | class TrainingCallbacks(): |
36 | on_prepare: Callable[[], None] = const() | 36 | on_prepare: Callable[[], None] = const() |
37 | on_model: Callable[[], torch.nn.Module] = const(None) | 37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) |
38 | on_log: Callable[[], dict[str, Any]] = const({}) | 38 | on_log: Callable[[], dict[str, Any]] = const({}) |
39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
40 | on_before_optimize: Callable[[float, int], None] = const() | 40 | on_before_optimize: Callable[[float, int], None] = const() |
@@ -51,7 +51,11 @@ class TrainingStrategyPrepareCallable(Protocol): | |||
51 | accelerator: Accelerator, | 51 | accelerator: Accelerator, |
52 | text_encoder: CLIPTextModel, | 52 | text_encoder: CLIPTextModel, |
53 | unet: UNet2DConditionModel, | 53 | unet: UNet2DConditionModel, |
54 | *args | 54 | optimizer: torch.optim.Optimizer, |
55 | train_dataloader: DataLoader, | ||
56 | val_dataloader: Optional[DataLoader], | ||
57 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
58 | **kwargs | ||
55 | ) -> Tuple: ... | 59 | ) -> Tuple: ... |
56 | 60 | ||
57 | 61 | ||
@@ -92,7 +96,6 @@ def save_samples( | |||
92 | sample_scheduler: DPMSolverMultistepScheduler, | 96 | sample_scheduler: DPMSolverMultistepScheduler, |
93 | train_dataloader: DataLoader, | 97 | train_dataloader: DataLoader, |
94 | val_dataloader: Optional[DataLoader], | 98 | val_dataloader: Optional[DataLoader], |
95 | dtype: torch.dtype, | ||
96 | output_dir: Path, | 99 | output_dir: Path, |
97 | seed: int, | 100 | seed: int, |
98 | step: int, | 101 | step: int, |
@@ -107,15 +110,6 @@ def save_samples( | |||
107 | grid_cols = min(batch_size, 4) | 110 | grid_cols = min(batch_size, 4) |
108 | grid_rows = (num_batches * batch_size) // grid_cols | 111 | grid_rows = (num_batches * batch_size) // grid_cols |
109 | 112 | ||
110 | unet = accelerator.unwrap_model(unet) | ||
111 | text_encoder = accelerator.unwrap_model(text_encoder) | ||
112 | |||
113 | orig_unet_dtype = unet.dtype | ||
114 | orig_text_encoder_dtype = text_encoder.dtype | ||
115 | |||
116 | unet.to(dtype=dtype) | ||
117 | text_encoder.to(dtype=dtype) | ||
118 | |||
119 | pipeline = VlpnStableDiffusion( | 113 | pipeline = VlpnStableDiffusion( |
120 | text_encoder=text_encoder, | 114 | text_encoder=text_encoder, |
121 | vae=vae, | 115 | vae=vae, |
@@ -172,11 +166,6 @@ def save_samples( | |||
172 | image_grid = make_grid(all_samples, grid_rows, grid_cols) | 166 | image_grid = make_grid(all_samples, grid_rows, grid_cols) |
173 | image_grid.save(file_path, quality=85) | 167 | image_grid.save(file_path, quality=85) |
174 | 168 | ||
175 | unet.to(dtype=orig_unet_dtype) | ||
176 | text_encoder.to(dtype=orig_text_encoder_dtype) | ||
177 | |||
178 | del unet | ||
179 | del text_encoder | ||
180 | del generator | 169 | del generator |
181 | del pipeline | 170 | del pipeline |
182 | 171 | ||
@@ -393,7 +382,7 @@ def train_loop( | |||
393 | ) | 382 | ) |
394 | global_progress_bar.set_description("Total progress") | 383 | global_progress_bar.set_description("Total progress") |
395 | 384 | ||
396 | model = callbacks.on_model() | 385 | model = callbacks.on_accum_model() |
397 | on_log = callbacks.on_log | 386 | on_log = callbacks.on_log |
398 | on_train = callbacks.on_train | 387 | on_train = callbacks.on_train |
399 | on_before_optimize = callbacks.on_before_optimize | 388 | on_before_optimize = callbacks.on_before_optimize |
@@ -559,8 +548,10 @@ def train( | |||
559 | prior_loss_weight: float = 1.0, | 548 | prior_loss_weight: float = 1.0, |
560 | **kwargs, | 549 | **kwargs, |
561 | ): | 550 | ): |
562 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( | 551 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
563 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 552 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) |
553 | |||
554 | kwargs.update(extra) | ||
564 | 555 | ||
565 | vae.to(accelerator.device, dtype=dtype) | 556 | vae.to(accelerator.device, dtype=dtype) |
566 | 557 | ||