-rw-r--r--  train_ti.py            |  7 ++++---
-rw-r--r--  trainer/base.py        |  2 +-
-rw-r--r--  trainer/ti.py          |  2 +-
-rw-r--r--  training/functional.py | 19 +++++++++----------
4 files changed, 15 insertions(+), 15 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index deed84c..a4e2dde 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -512,7 +512,7 @@ class TextualInversionCheckpointer(Checkpointer):
             checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
         )
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step):
         ema_context = self.ema_embeddings.apply_temporary(
             self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
@@ -808,7 +808,6 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         model=text_encoder,
-        checkpointer=checkpointer,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         loss_step=loss_step_,
@@ -819,7 +818,9 @@ def main():
         on_log=on_log,
         on_train=on_train,
         on_after_optimize=on_after_optimize,
-        on_eval=on_eval
+        on_eval=on_eval,
+        on_sample=checkpointer.save_samples,
+        on_checkpoint=checkpointer.checkpoint,
     )
 
 
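All three checkpointer files in this commit swap @torch.inference_mode() for @torch.no_grad() on save_samples. The two contexts are not interchangeable: inference_mode() additionally marks every tensor created under it as an inference tensor, which can never be mutated or recorded by autograd afterwards, while no_grad() merely disables gradient tracking. Since save_samples temporarily writes EMA weights into the live training embeddings, the weaker no_grad() context is presumably what makes this safe. A minimal sketch of the difference:

    import torch

    with torch.no_grad():
        a = torch.ones(2)
    a += 1  # fine: a is an ordinary tensor, only grad tracking was off

    with torch.inference_mode():
        b = torch.ones(2)
    b += 1  # RuntimeError: in-place update to an inference tensor
            # is not allowed outside InferenceMode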
diff --git a/trainer/base.py b/trainer/base.py
index e700dd6..1f85e71 100644
--- a/trainer/base.py
+++ b/trainer/base.py
@@ -74,7 +74,7 @@ class Checkpointer():
     def checkpoint(self, step: int, postfix: str):
         pass
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step: int):
         print(f"Saving samples for step {step}...")
 
diff --git a/trainer/ti.py b/trainer/ti.py
index 15cf747..388acd3 100644
--- a/trainer/ti.py
+++ b/trainer/ti.py
@@ -42,7 +42,7 @@ class TextualInversionCheckpointer(Checkpointer):
             checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
         )
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step):
         ema_context = self.ema_embeddings.apply_temporary(
             self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
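ema_embeddings.apply_temporary is not part of this diff; from its use in save_samples above it appears to be a context manager that loads the EMA shadow weights into the given parameters for the duration of sampling and restores the originals on exit. A rough sketch of that pattern (everything here, including the shadow_params attribute, is an assumption, not the repo's actual code):

    import torch
    from contextlib import contextmanager

    @contextmanager
    def apply_temporary(self, parameters):
        # Assumed behaviour: stash the live weights, load the EMA shadow
        # weights for sampling, then restore the originals afterwards.
        parameters = list(parameters)
        originals = [p.detach().clone() for p in parameters]
        try:
            with torch.no_grad():
                for p, shadow in zip(parameters, self.shadow_params):
                    p.copy_(shadow)
            yield
        finally:
            with torch.no_grad():
                for p, o in zip(parameters, originals):
                    p.copy_(o)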
diff --git a/training/functional.py b/training/functional.py
index 2d81eca..c100ea2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -17,7 +17,6 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
-from trainer.base import Checkpointer
 
 
 def const(result=None):
@@ -205,7 +204,6 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -217,7 +215,9 @@ def train_loop(
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
     on_before_optimize: Callable[[int], None] = const(),
     on_after_optimize: Callable[[float], None] = const(),
-    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
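Both new hooks default to const(), the small factory whose def is visible at the end of the first training/functional.py hunk. Its body is not shown in this diff, but given the signature const(result=None) and the way its results are called above, it presumably returns a function that accepts any arguments and returns result, so unused hooks are harmless no-ops:

    from contextlib import nullcontext

    def const(result=None):
        # Build a callback that ignores its arguments and returns `result`.
        def fn(*args, **kwargs):
            return result
        return fn

    on_checkpoint = const()          # on_checkpoint(step, postfix) -> None
    on_eval = const(nullcontext())   # on_eval() -> a no-op context manager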
@@ -253,10 +253,10 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
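Note the asymmetry this hunk preserves: on_sample also fires on epoch 0 (with sample_frequency = 5 it runs at epochs 0, 5, 10, ...), while the epoch != 0 guard keeps on_checkpoint from writing a checkpoint before any training has happened.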
@@ -347,19 +347,18 @@ def train_loop(
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset)
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step + global_step_offset, "end")
             accelerator.end_training()
-        quit()