Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/dreambooth.py |  2 +-
-rw-r--r--  training/strategy/lora.py       | 12 +++++++++---
-rw-r--r--  training/strategy/ti.py         |  2 +-
3 files changed, 11 insertions, 5 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         params_to_clip = [unet.parameters()]
         if epoch < train_text_encoder_epochs:
             params_to_clip.append(text_encoder.parameters())
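The dropped lr parameter is the whole change here: gradient clipping never used it. For illustration, a minimal sketch of the clipping pattern this callback implements, using plain torch.nn.utils.clip_grad_norm_ in place of accelerate's wrapper; the module sizes and hyperparameter values below are placeholders, not from this repo:

    # Sketch of the hunk's clipping logic with plain torch; `unet` and
    # `text_encoder` stand in for the real models, and the hyperparameters
    # are placeholder values.
    import itertools

    import torch

    unet = torch.nn.Linear(16, 16)
    text_encoder = torch.nn.Linear(16, 16)
    train_text_encoder_epochs = 10
    max_grad_norm = 1.0


    def on_before_optimize(epoch: int) -> None:
        # Clip the unet always; clip the text encoder only while it is
        # still being trained (epoch < train_text_encoder_epochs).
        params_to_clip = [unet.parameters()]
        if epoch < train_text_encoder_epochs:
            params_to_clip.append(text_encoder.parameters())
        torch.nn.utils.clip_grad_norm_(
            itertools.chain(*params_to_clip), max_grad_norm
        )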
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )
 
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None
 
     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr
 
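On the caller's side, on_after_optimize now expects a dict of per-group learning rates instead of a single float. A minimal sketch of how such a dict could be built, assuming the trainer names its optimizer param groups; collect_lrs and the group layout below are assumptions, not code from this commit:

    # Hypothetical helper: map each optimizer param group to its current lr,
    # keyed by the group's "name" entry, falling back to its index as a
    # string (which is where a key like "0" would come from).
    import torch


    def collect_lrs(optimizer: torch.optim.Optimizer) -> dict[str, float]:
        return {
            str(group.get("name", i)): group["lr"]
            for i, group in enumerate(optimizer.param_groups)
        }


    other = torch.nn.Parameter(torch.zeros(8))
    emb = torch.nn.Parameter(torch.zeros(8))
    optimizer = torch.optim.AdamW(
        [
            {"params": [other], "lr": 1e-5},              # unnamed -> key "0"
            {"params": [emb], "name": "emb", "lr": 1e-4},
        ]
    )
    print(collect_lrs(optimizer))  # {'0': 1e-05, 'emb': 0.0001}

With that layout, the hunk's `lrs["emb"] or lrs["0"]` resolves to the embedding group's rate and falls back to the first, unnamed group only when the "emb" rate is falsy (e.g. 0.0); both keys are expected to be present.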
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if use_emb_decay:
             params = [
                 p
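ti.py gets the same signature change, and its use_emb_decay branch applies a learning-rate-scaled decay to the selected embedding rows. A minimal sketch of that pattern; the multiplicative decay form below is an assumption, since only `lambda_ = emb_decay * lr` appears in this commit:

    # Hypothetical decay step: scale the decay strength by the current lr
    # (lambda_ = emb_decay * lr, as in lora.py above) and shrink the
    # embedding rows in place. The multiplicative form is an assumption.
    import torch


    @torch.no_grad()
    def apply_emb_decay(w: torch.Tensor, emb_decay: float, lr: float) -> None:
        lambda_ = emb_decay * lr
        if lambda_ != 0:
            w.mul_(1.0 - lambda_)


    w = torch.ones(4, 768)
    apply_emb_decay(w, emb_decay=1e-2, lr=1e-4)  # shrinks the rows slightly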