Diffstat (limited to 'training/strategy')
 training/strategy/dreambooth.py | 6 ++----
 training/strategy/lora.py       | 9 ++++-----
 training/strategy/ti.py         | 6 ++----
 3 files changed, 8 insertions(+), 13 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 9808027..0286673 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -84,11 +84,9 @@ def dreambooth_strategy_callbacks(
         else:
             return nullcontext()
 
-    def on_accum_model():
-        return unet
-
     @contextmanager
     def on_train(epoch: int):
+        unet.train()
         tokenizer.train()
 
         if epoch < train_text_encoder_epochs:
@@ -101,6 +99,7 @@ def dreambooth_strategy_callbacks(
 
     @contextmanager
     def on_eval():
+        unet.eval()
         tokenizer.eval()
         text_encoder.eval()
 
@@ -174,7 +173,6 @@ def dreambooth_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
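The dreambooth hunks above drop the on_accum_model accessor and instead toggle module modes directly inside the on_train/on_eval context managers. A minimal sketch of that pattern, using nn.Linear stand-ins for the real unet and text_encoder (the stand-ins and the assert-based usage below are illustrative assumptions, not code from this repo):

```python
from contextlib import contextmanager
import torch.nn as nn

unet = nn.Linear(4, 4)          # hypothetical stand-in for the diffusion UNet
text_encoder = nn.Linear(4, 4)  # hypothetical stand-in for the text encoder

@contextmanager
def on_train(epoch: int):
    # Entering the context puts the trainable modules in train mode.
    unet.train()
    text_encoder.train()
    yield

@contextmanager
def on_eval():
    # Entering the context puts the modules in eval mode (disables
    # dropout, freezes batch-norm statistics, etc.).
    unet.eval()
    text_encoder.eval()
    yield

with on_train(epoch=0):
    assert unet.training
with on_eval():
    assert not unet.training
```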
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 6730dc9..80ffa9c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -64,17 +64,17 @@ def lora_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    def on_accum_model():
-        return unet
-
     @contextmanager
     def on_train(epoch: int):
-        tokenizer.train()
+        unet.train()
         text_encoder.train()
+        tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
+        unet.eval()
+        text_encoder.eval()
         tokenizer.eval()
         yield
 
@@ -152,7 +152,6 @@ def lora_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
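With on_accum_model removed from all three strategies, a consuming training loop only needs the two context managers. A sketch of how such a loop might look, assuming TrainingCallbacks is a plain container of callables (its real definition, which also carries on_before_optimize and the other hooks shown in the diff, lives outside this change):

```python
from dataclasses import dataclass
from typing import Callable, ContextManager

@dataclass
class TrainingCallbacks:
    # Assumed shape for illustration only.
    on_train: Callable[[int], ContextManager]
    on_eval: Callable[[], ContextManager]

def run_epoch(callbacks: TrainingCallbacks, epoch: int) -> None:
    with callbacks.on_train(epoch):
        pass  # forward/backward passes; modules are now in train mode
    with callbacks.on_eval():
        pass  # validation/sampling; modules are now in eval mode
```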
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 55e9934..6a637c3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks(
         else:
             return nullcontext()
 
-    def on_accum_model():
-        return text_encoder.text_model.embeddings.token_override_embedding.params
-
     @contextmanager
     def on_train(epoch: int):
+        text_encoder.text_model.embeddings.token_override_embedding.params.train()
         tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
+        text_encoder.text_model.embeddings.token_override_embedding.params.eval()
         tokenizer.eval()
 
         with ema_context():
@@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
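In ti.py the toggling is deliberately narrow: only text_encoder.text_model.embeddings.token_override_embedding.params changes mode, not the whole text encoder. That works because nn.Module.train()/.eval() set the training flag recursively on the receiver and its own children only. A sketch with hypothetical modules:

```python
import torch.nn as nn

# Hypothetical stand-ins: "override" plays the role of the
# token_override_embedding.params module from the diff above.
encoder = nn.ModuleDict({
    "embeddings": nn.Embedding(10, 4),
    "override": nn.Linear(4, 4),
})

encoder.eval()               # puts the whole encoder subtree in eval mode
encoder["override"].train()  # flips train mode only for the override module

assert encoder["override"].training
assert not encoder["embeddings"].training
```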