author     Volpeon <git@volpeon.ink>  2023-04-03 22:25:20 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-03 22:25:20 +0200
commit     2e654c017780d37f3304436e2feb84b619f1c023 (patch)
tree       8a248fe17c3512110de9fcfed7f7bfd708b3b8da /training
parent     TI: Delta learning (diff)
Improved sparse embeddings
Diffstat (limited to 'training')
-rw-r--r--  training/strategy/ti.py  8
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 16baa34..95128da 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -69,7 +69,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -81,13 +81,13 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+                text_encoder.text_model.embeddings.token_override_embedding.params.parameters()
             )
         else:
             return nullcontext()
 
     def on_accum_model():
-        return text_encoder.text_model.embeddings.temp_token_embedding
+        return text_encoder.text_model.embeddings.token_override_embedding.params
 
     @contextmanager
     def on_train(epoch: int):
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(zero_ids, lr: float):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
 
     def on_log():
         if ema_embeddings is not None:
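
The `token_override_embedding` module that this patch switches to is not part of the diff itself. As a hedged illustration only, a sparse override embedding whose trainable vectors live in a `params` container might look like the sketch below; all names except those visible in the diff are assumptions.

# Hypothetical sketch, not the repository's actual implementation: a sparse
# token-override embedding where only the overridden rows are trainable and
# exposed through `.params`, which is what an EMA helper would track.
import torch
import torch.nn as nn


class TokenOverrideEmbedding(nn.Module):
    def __init__(self, base_embedding: nn.Embedding, override_ids: list[int]):
        super().__init__()
        self.base = base_embedding            # frozen base token embedding table
        self.override_ids = override_ids      # token ids that get learned overrides
        # Only the overridden rows are trainable ("sparse" embedding updates).
        self.params = nn.ParameterList(
            nn.Parameter(base_embedding.weight[i].detach().clone())
            for i in override_ids
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embeds = self.base(input_ids)
        # Replace the rows of overridden tokens with their trainable vectors.
        for param, token_id in zip(self.params, self.override_ids):
            embeds = torch.where(
                (input_ids == token_id).unsqueeze(-1), param, embeds
            )
        return embeds

With a module of this shape, the EMA call in the diff, ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters()), tracks only the sparse override vectors rather than the full embedding table.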