diff options
| -rw-r--r-- | train_ti.py | 7 | ||||
| -rw-r--r-- | trainer/base.py | 2 | ||||
| -rw-r--r-- | trainer/ti.py | 2 | ||||
| -rw-r--r-- | training/functional.py | 19 |
4 files changed, 15 insertions, 15 deletions
diff --git a/train_ti.py b/train_ti.py index deed84c..a4e2dde 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -512,7 +512,7 @@ class TextualInversionCheckpointer(Checkpointer): | |||
| 512 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 512 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
| 513 | ) | 513 | ) |
| 514 | 514 | ||
| 515 | @torch.inference_mode() | 515 | @torch.no_grad() |
| 516 | def save_samples(self, step): | 516 | def save_samples(self, step): |
| 517 | ema_context = self.ema_embeddings.apply_temporary( | 517 | ema_context = self.ema_embeddings.apply_temporary( |
| 518 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 518 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() |
| @@ -808,7 +808,6 @@ def main(): | |||
| 808 | optimizer=optimizer, | 808 | optimizer=optimizer, |
| 809 | lr_scheduler=lr_scheduler, | 809 | lr_scheduler=lr_scheduler, |
| 810 | model=text_encoder, | 810 | model=text_encoder, |
| 811 | checkpointer=checkpointer, | ||
| 812 | train_dataloader=train_dataloader, | 811 | train_dataloader=train_dataloader, |
| 813 | val_dataloader=val_dataloader, | 812 | val_dataloader=val_dataloader, |
| 814 | loss_step=loss_step_, | 813 | loss_step=loss_step_, |
| @@ -819,7 +818,9 @@ def main(): | |||
| 819 | on_log=on_log, | 818 | on_log=on_log, |
| 820 | on_train=on_train, | 819 | on_train=on_train, |
| 821 | on_after_optimize=on_after_optimize, | 820 | on_after_optimize=on_after_optimize, |
| 822 | on_eval=on_eval | 821 | on_eval=on_eval, |
| 822 | on_sample=checkpointer.save_samples, | ||
| 823 | on_checkpoint=checkpointer.checkpoint, | ||
| 823 | ) | 824 | ) |
| 824 | 825 | ||
| 825 | 826 | ||
diff --git a/trainer/base.py b/trainer/base.py index e700dd6..1f85e71 100644 --- a/trainer/base.py +++ b/trainer/base.py | |||
| @@ -74,7 +74,7 @@ class Checkpointer(): | |||
| 74 | def checkpoint(self, step: int, postfix: str): | 74 | def checkpoint(self, step: int, postfix: str): |
| 75 | pass | 75 | pass |
| 76 | 76 | ||
| 77 | @torch.inference_mode() | 77 | @torch.no_grad() |
| 78 | def save_samples(self, step: int): | 78 | def save_samples(self, step: int): |
| 79 | print(f"Saving samples for step {step}...") | 79 | print(f"Saving samples for step {step}...") |
| 80 | 80 | ||
diff --git a/trainer/ti.py b/trainer/ti.py index 15cf747..388acd3 100644 --- a/trainer/ti.py +++ b/trainer/ti.py | |||
| @@ -42,7 +42,7 @@ class TextualInversionCheckpointer(Checkpointer): | |||
| 42 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 42 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
| 43 | ) | 43 | ) |
| 44 | 44 | ||
| 45 | @torch.inference_mode() | 45 | @torch.no_grad() |
| 46 | def save_samples(self, step): | 46 | def save_samples(self, step): |
| 47 | ema_context = self.ema_embeddings.apply_temporary( | 47 | ema_context = self.ema_embeddings.apply_temporary( |
| 48 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() | 48 | self.text_encoder.text_model.embeddings.temp_token_embedding.parameters() |
diff --git a/training/functional.py b/training/functional.py index 2d81eca..c100ea2 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -17,7 +17,6 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe | |||
| 17 | from models.clip.util import get_extended_embeddings | 17 | from models.clip.util import get_extended_embeddings |
| 18 | from models.clip.tokenizer import MultiCLIPTokenizer | 18 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 19 | from training.util import AverageMeter | 19 | from training.util import AverageMeter |
| 20 | from trainer.base import Checkpointer | ||
| 21 | 20 | ||
| 22 | 21 | ||
| 23 | def const(result=None): | 22 | def const(result=None): |
| @@ -205,7 +204,6 @@ def train_loop( | |||
| 205 | optimizer: torch.optim.Optimizer, | 204 | optimizer: torch.optim.Optimizer, |
| 206 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 205 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 207 | model: torch.nn.Module, | 206 | model: torch.nn.Module, |
| 208 | checkpointer: Checkpointer, | ||
| 209 | train_dataloader: DataLoader, | 207 | train_dataloader: DataLoader, |
| 210 | val_dataloader: DataLoader, | 208 | val_dataloader: DataLoader, |
| 211 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 209 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
| @@ -217,7 +215,9 @@ def train_loop( | |||
| 217 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), | 215 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), |
| 218 | on_before_optimize: Callable[[int], None] = const(), | 216 | on_before_optimize: Callable[[int], None] = const(), |
| 219 | on_after_optimize: Callable[[float], None] = const(), | 217 | on_after_optimize: Callable[[float], None] = const(), |
| 220 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 218 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()), |
| 219 | on_sample: Callable[[int], None] = const(), | ||
| 220 | on_checkpoint: Callable[[int, str], None] = const(), | ||
| 221 | ): | 221 | ): |
| 222 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | 222 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) |
| 223 | num_val_steps_per_epoch = len(val_dataloader) | 223 | num_val_steps_per_epoch = len(val_dataloader) |
| @@ -253,10 +253,10 @@ def train_loop( | |||
| 253 | for epoch in range(num_epochs): | 253 | for epoch in range(num_epochs): |
| 254 | if accelerator.is_main_process: | 254 | if accelerator.is_main_process: |
| 255 | if epoch % sample_frequency == 0: | 255 | if epoch % sample_frequency == 0: |
| 256 | checkpointer.save_samples(global_step + global_step_offset) | 256 | on_sample(global_step + global_step_offset) |
| 257 | 257 | ||
| 258 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 258 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
| 259 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 259 | on_checkpoint(global_step + global_step_offset, "training") |
| 260 | 260 | ||
| 261 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 261 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| 262 | local_progress_bar.reset() | 262 | local_progress_bar.reset() |
| @@ -347,19 +347,18 @@ def train_loop( | |||
| 347 | if avg_acc_val.avg.item() > max_acc_val: | 347 | if avg_acc_val.avg.item() > max_acc_val: |
| 348 | accelerator.print( | 348 | accelerator.print( |
| 349 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 349 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
| 350 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | 350 | on_checkpoint(global_step + global_step_offset, "milestone") |
| 351 | max_acc_val = avg_acc_val.avg.item() | 351 | max_acc_val = avg_acc_val.avg.item() |
| 352 | 352 | ||
| 353 | # Create the pipeline using the trained modules and save it. | 353 | # Create the pipeline using the trained modules and save it. |
| 354 | if accelerator.is_main_process: | 354 | if accelerator.is_main_process: |
| 355 | print("Finished!") | 355 | print("Finished!") |
| 356 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 356 | on_checkpoint(global_step + global_step_offset, "end") |
| 357 | checkpointer.save_samples(global_step + global_step_offset) | 357 | on_sample(global_step + global_step_offset) |
| 358 | accelerator.end_training() | 358 | accelerator.end_training() |
| 359 | 359 | ||
| 360 | except KeyboardInterrupt: | 360 | except KeyboardInterrupt: |
| 361 | if accelerator.is_main_process: | 361 | if accelerator.is_main_process: |
| 362 | print("Interrupted") | 362 | print("Interrupted") |
| 363 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 363 | on_checkpoint(global_step + global_step_offset, "end") |
| 364 | accelerator.end_training() | 364 | accelerator.end_training() |
| 365 | quit() | ||
