author    Volpeon <git@volpeon.ink>  2023-01-14 22:03:01 +0100
committer Volpeon <git@volpeon.ink>  2023-01-14 22:03:01 +0100
commit    fc11c86142915d6c3935d28a3321b3ae91b613ef (patch)
tree      5d2c84b1ff32e779db868da1248ed24a97cde3c2
parent    WIP: Modularization ("free(): invalid pointer" my ass) (diff)
Update
-rw-r--r--  train_ti.py            |  7
-rw-r--r--  trainer/base.py        |  2
-rw-r--r--  trainer/ti.py          |  2
-rw-r--r--  training/functional.py | 19
4 files changed, 15 insertions(+), 15 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index deed84c..a4e2dde 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -512,7 +512,7 @@ class TextualInversionCheckpointer(Checkpointer):
                     checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                 )
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step):
         ema_context = self.ema_embeddings.apply_temporary(
             self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
@@ -808,7 +808,6 @@ def main():
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         model=text_encoder,
-        checkpointer=checkpointer,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         loss_step=loss_step_,
@@ -819,7 +818,9 @@ def main():
         on_log=on_log,
         on_train=on_train,
         on_after_optimize=on_after_optimize,
-        on_eval=on_eval
+        on_eval=on_eval,
+        on_sample=checkpointer.save_samples,
+        on_checkpoint=checkpointer.checkpoint,
     )
 
 
diff --git a/trainer/base.py b/trainer/base.py
index e700dd6..1f85e71 100644
--- a/trainer/base.py
+++ b/trainer/base.py
@@ -74,7 +74,7 @@ class Checkpointer():
     def checkpoint(self, step: int, postfix: str):
         pass
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step: int):
         print(f"Saving samples for step {step}...")
 
diff --git a/trainer/ti.py b/trainer/ti.py
index 15cf747..388acd3 100644
--- a/trainer/ti.py
+++ b/trainer/ti.py
@@ -42,7 +42,7 @@ class TextualInversionCheckpointer(Checkpointer):
                     checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                 )
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def save_samples(self, step):
         ema_context = self.ema_embeddings.apply_temporary(
             self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
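
All three files above make the same change: save_samples drops @torch.inference_mode() in favour of @torch.no_grad(). The commit message doesn't say why. As a standalone sketch of the difference between the two modes (nothing below comes from this repository): a tensor created under no_grad() stays an ordinary tensor, while a tensor created under inference_mode() is rejected the moment autograd needs to save it for backward.

import torch

w = torch.randn(8, 2, requires_grad=True)

with torch.no_grad():
    plain = torch.randn(4, 8)       # ordinary tensor, merely built without grad tracking

with torch.inference_mode():
    frozen = torch.randn(4, 8)      # inference tensor: cheaper, but restricted afterwards

(plain @ w).sum().backward()        # fine: the no_grad tensor can feed autograd later
try:
    (frozen @ w).sum().backward()   # raises: inference tensors cannot be saved for backward
except RuntimeError as err:
    print(f"rejected by autograd: {err}")

That restriction is the kind of thing that can surprise a sampling routine which, like save_samples here, temporarily swaps live embedding parameters; whether that is the actual motivation for the switch is not stated in the commit.
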
diff --git a/training/functional.py b/training/functional.py
index 2d81eca..c100ea2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -17,7 +17,6 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
-from trainer.base import Checkpointer
 
 
 def const(result=None):
@@ -205,7 +204,6 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -217,7 +215,9 @@ def train_loop(
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
     on_before_optimize: Callable[[int], None] = const(),
     on_after_optimize: Callable[[float], None] = const(),
-    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
+    on_sample: Callable[[int], None] = const(),
+    on_checkpoint: Callable[[int, str], None] = const(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
@@ -253,10 +253,10 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
@@ -347,19 +347,18 @@ def train_loop(
                     if avg_acc_val.avg.item() > max_acc_val:
                         accelerator.print(
                             f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                        checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                        on_checkpoint(global_step + global_step_offset, "milestone")
                         max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset)
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            on_checkpoint(global_step + global_step_offset, "end")
             accelerator.end_training()
-        quit()
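
Taken together, the training/functional.py hunks decouple train_loop from trainer.base.Checkpointer: the checkpointer argument is gone, and two plain callbacks, on_sample(step) and on_checkpoint(step, postfix), take its place, each defaulting to a no-op built by const(). train_ti.py then simply forwards the checkpointer's bound methods, as shown in its hunk above. A minimal standalone sketch of that wiring (heavily simplified: const's body is assumed from its call sites, and DummyCheckpointer plus the step bookkeeping are invented for the example):

from typing import Callable


def const(result=None):
    # assumed shape of training/functional.py's const(): a factory for
    # do-nothing callbacks that ignore their arguments and return `result`
    def fn(*args, **kwargs):
        return result
    return fn


def train_loop(
    num_epochs: int,
    sample_frequency: int = 1,
    checkpoint_frequency: int = 2,
    on_sample: Callable[[int], None] = const(),
    on_checkpoint: Callable[[int, str], None] = const(),
):
    global_step = 0
    for epoch in range(num_epochs):
        if epoch % sample_frequency == 0:
            on_sample(global_step)
        if epoch % checkpoint_frequency == 0 and epoch != 0:
            on_checkpoint(global_step, "training")
        global_step += 10  # stand-in for the real per-epoch training steps
    on_checkpoint(global_step, "end")
    on_sample(global_step)


class DummyCheckpointer:
    # invented for the example; stands in for the repository's checkpointer classes
    def save_samples(self, step: int):
        print(f"Saving samples for step {step}...")

    def checkpoint(self, step: int, postfix: str):
        print(f"Saving checkpoint for step {step} ({postfix})...")


if __name__ == "__main__":
    ckpt = DummyCheckpointer()
    # as in train_ti.py after this commit: pass bound methods, not the object itself
    train_loop(num_epochs=4, on_sample=ckpt.save_samples, on_checkpoint=ckpt.checkpoint)

With this shape, train_loop no longer needs to import anything from trainer/, and a caller can hook sampling and checkpointing, or leave either as a no-op, without subclassing Checkpointer.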