Diffstat (limited to 'training')
-rw-r--r--  training/functional.py          | 6
-rw-r--r--  training/strategy/dreambooth.py | 6
-rw-r--r--  training/strategy/lora.py       | 9
-rw-r--r--  training/strategy/ti.py         | 6
4 files changed, 8 insertions(+), 19 deletions(-)
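
Taken together, the hunks below drop the on_accum_model hook from TrainingCallbacks: instead of train_loop fetching one "accumulation model" and toggling its mode centrally, each strategy now switches its own modules between training and evaluation mode inside its on_train/on_eval context managers. A minimal, self-contained sketch of that contract (the toy modules and asserts are illustrative, not code from this repo):

    from contextlib import contextmanager
    import torch

    # Hypothetical stand-ins for the modules a strategy manages
    # (e.g. unet, text_encoder); any torch.nn.Module behaves the same.
    module_a = torch.nn.Linear(4, 4)
    module_b = torch.nn.Linear(4, 4)

    @contextmanager
    def on_train(epoch: int):
        # The strategy, not the shared train loop, flips modes now.
        module_a.train()
        module_b.train()
        yield

    @contextmanager
    def on_eval():
        module_a.eval()
        module_b.eval()
        yield

    with on_train(epoch=0):
        assert module_a.training and module_b.training
    with on_eval():
        assert not module_a.training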
diff --git a/training/functional.py b/training/functional.py
index c30d1c0..4d83df1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,6 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[float, int], Any] = const()
@@ -461,7 +460,6 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")
 
-    model = callbacks.on_accum_model()
     on_log = callbacks.on_log
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
@@ -498,8 +496,6 @@ def train_loop(
         local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
         local_progress_bar.reset()
 
-        model.train()
-
         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
@@ -560,8 +556,6 @@ def train_loop(
         on_after_epoch()
 
         if val_dataloader is not None:
-            model.eval()
-
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()
 
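With the hook removed, train_loop itself never calls .train() or .eval(); entering the strategy's context managers is what establishes the mode. Assuming the loop wraps training in on_train and validation in on_eval, as the deletions above suggest, the mode handling reduces to roughly the following runnable sketch (all names are placeholders):

    from contextlib import nullcontext

    # Placeholder callbacks; real ones come from a strategy and call
    # .train()/.eval() on the modules they own.
    on_train = lambda epoch: nullcontext()
    on_eval = lambda: nullcontext()
    train_dataloader = [0, 1]
    val_dataloader = [0]
    num_epochs = 1

    for epoch in range(num_epochs):
        with on_train(epoch):        # strategy puts its modules in train mode
            for step, batch in enumerate(train_dataloader):
                pass                 # loss_step(...) and optimizer step elided
        if val_dataloader is not None:
            with on_eval():          # strategy puts its modules in eval mode
                pass                 # validation pass elided
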
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 9808027..0286673 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -84,11 +84,9 @@ def dreambooth_strategy_callbacks(
         else:
             return nullcontext()
 
-    def on_accum_model():
-        return unet
-
     @contextmanager
     def on_train(epoch: int):
+        unet.train()
         tokenizer.train()
 
         if epoch < train_text_encoder_epochs:
@@ -101,6 +99,7 @@ def dreambooth_strategy_callbacks(
 
     @contextmanager
     def on_eval():
+        unet.eval()
         tokenizer.eval()
         text_encoder.eval()
 
@@ -174,7 +173,6 @@ def dreambooth_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 6730dc9..80ffa9c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -64,17 +64,17 @@ def lora_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    def on_accum_model():
-        return unet
-
     @contextmanager
     def on_train(epoch: int):
-        tokenizer.train()
+        unet.train()
         text_encoder.train()
+        tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
+        unet.eval()
+        text_encoder.eval()
         tokenizer.eval()
         yield
 
@@ -152,7 +152,6 @@ def lora_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 55e9934..6a637c3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks(
         else:
             return nullcontext()
 
-    def on_accum_model():
-        return text_encoder.text_model.embeddings.token_override_embedding.params
-
     @contextmanager
     def on_train(epoch: int):
+        text_encoder.text_model.embeddings.token_override_embedding.params.train()
         tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
+        text_encoder.text_model.embeddings.token_override_embedding.params.eval()
         tokenizer.eval()
 
         with ema_context():
@@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
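
One detail worth noting in the ti.py hunks: .train() and .eval() are inherited from torch.nn.Module, so they work on any submodule, including a bare parameter container like the one targeted here. A quick check of that behavior, using torch.nn.ParameterList as a stand-in (whether token_override_embedding.params is actually a ParameterList is an assumption):

    import torch

    # Any nn.Module, including a parameter container, carries the
    # .training flag that .train()/.eval() toggle.
    params = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(2))])
    params.train()
    assert params.training
    params.eval()
    assert not params.training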