summary | refs | log | tree | commit | diff | stats
path: root/training/strategy
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2023-04-08 07:58:14 +0200
committer: Volpeon <git@volpeon.ink> 2023-04-08 07:58:14 +0200
commit: 5e84594c56237cd2c7d7f80858e5da8c11aa3f89 (patch)
tree: b1483a52fb853aecb7b73635cded3cce61edf125 /training/strategy
parent: Fix (diff)
download: textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.gz
textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.bz2
textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.zip
Update
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/dreambooth.py  | 2
-rw-r--r--  training/strategy/lora.py        | 12
-rw-r--r--  training/strategy/ti.py          | 2
3 files changed, 11 insertions, 5 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
106 with ema_context(): 106 with ema_context():
107 yield 107 yield
108 108
109 def on_before_optimize(lr: float, epoch: int): 109 def on_before_optimize(epoch: int):
110 params_to_clip = [unet.parameters()] 110 params_to_clip = [unet.parameters()]
111 if epoch < train_text_encoder_epochs: 111 if epoch < train_text_encoder_epochs:
112 params_to_clip.append(text_encoder.parameters()) 112 params_to_clip.append(text_encoder.parameters())
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
79 tokenizer.eval() 79 tokenizer.eval()
80 yield 80 yield
81 81
82 def on_before_optimize(lr: float, epoch: int): 82 def on_before_optimize(epoch: int):
83 if not pti_mode: 83 if not pti_mode:
84 accelerator.clip_grad_norm_( 84 accelerator.clip_grad_norm_(
85 itertools.chain(unet.parameters(), text_encoder.parameters()), 85 itertools.chain(
86 unet.parameters(),
87 text_encoder.text_model.encoder.parameters(),
88 text_encoder.text_model.final_layer_norm.parameters(),
89 ),
86 max_grad_norm 90 max_grad_norm
87 ) 91 )
88 92
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
95 return torch.stack(params) if len(params) != 0 else None 99 return torch.stack(params) if len(params) != 0 else None
96 100
97 @torch.no_grad() 101 @torch.no_grad()
98 def on_after_optimize(w, lr: float): 102 def on_after_optimize(w, lrs: dict[str, float]):
103 lr = lrs["emb"] or lrs["0"]
104
99 if use_emb_decay and w is not None: 105 if use_emb_decay and w is not None:
100 lambda_ = emb_decay * lr 106 lambda_ = emb_decay * lr
101 107
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
104 yield 104 yield
105 105
106 @torch.no_grad() 106 @torch.no_grad()
107 def on_before_optimize(lr: float, epoch: int): 107 def on_before_optimize(epoch: int):
108 if use_emb_decay: 108 if use_emb_decay:
109 params = [ 109 params = [
110 p 110 p