Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/dreambooth.py |  2 +-
-rw-r--r--  training/strategy/lora.py       | 12 +++++++++---
-rw-r--r--  training/strategy/ti.py         |  2 +-
3 files changed, 11 insertions(+), 5 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield

-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         params_to_clip = [unet.parameters()]
         if epoch < train_text_encoder_epochs:
             params_to_clip.append(text_encoder.parameters())
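Context for this signature change: the learning rate is no longer threaded through on_before_optimize, so the call site only passes the epoch. A minimal sketch of the trainer-side call after this commit, assuming a hypothetical training loop (callbacks, compute_loss, and the loop structure are illustrative, not taken from this commit):

    for epoch in range(num_train_epochs):
        for batch in train_dataloader:
            loss = compute_loss(batch)  # hypothetical loss helper
            accelerator.backward(loss)
            # lr is dropped from the signature; the epoch alone decides
            # whether the text encoder's parameters are still clipped.
            callbacks.on_before_optimize(epoch)
            optimizer.step()
            optimizer.zero_grad()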
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield

-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
-                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                itertools.chain(
+                    unet.parameters(),
+                    text_encoder.text_model.encoder.parameters(),
+                    text_encoder.text_model.final_layer_norm.parameters(),
+                ),
                 max_grad_norm
             )

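The clipping set now descends into text_encoder.text_model instead of taking text_encoder.parameters() wholesale, which leaves the embeddings submodule out of gradient clipping. A sketch of the equivalent selection, assuming text_encoder is a Hugging Face CLIPTextModel (whose text_model exposes embeddings, encoder, and final_layer_norm):

    import itertools

    # Clip the UNet plus the transformer encoder and final layer norm of
    # the text encoder; text_model.embeddings (token and position
    # embeddings) is deliberately left out of the chain.
    params_to_clip = itertools.chain(
        unet.parameters(),
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
    )
    accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)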
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
         return torch.stack(params) if len(params) != 0 else None

     @torch.no_grad()
-    def on_after_optimize(w, lr: float):
+    def on_after_optimize(w, lrs: dict[str, float]):
+        lr = lrs["emb"] or lrs["0"]
+
         if use_emb_decay and w is not None:
             lambda_ = emb_decay * lr

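on_after_optimize now receives per-group learning rates as a dict rather than a single float; lrs["emb"] or lrs["0"] prefers the embeddings group's rate and falls back to group "0" when the "emb" entry is falsy (0.0 or None). A sketch of how a trainer might build such a dict from the optimizer's param groups (the key scheme is an assumption mirroring the lookup, not shown in this commit):

    # Hypothetical construction of the lrs dict: key each param group by
    # its "name" entry when present, else by its index as a string.
    lrs = {
        str(group.get("name", i)): group["lr"]
        for i, group in enumerate(optimizer.param_groups)
    }
    callbacks.on_after_optimize(w, lrs)  # w: stacked embedding weights, or None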
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield

     @torch.no_grad()
-    def on_before_optimize(lr: float, epoch: int):
+    def on_before_optimize(epoch: int):
         if use_emb_decay:
             params = [
                 p
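Taken together, the three strategies converge on the same callback signatures after this commit. A hypothetical Protocol summarizing the resulting interface (the repo may not define such a type; this is illustrative only, and the return value of on_before_optimize is inferred from the torch.stack(...) hunk above):

    from typing import Optional, Protocol

    import torch

    class StrategyCallbacks(Protocol):
        # Gathers/clips parameters before optimizer.step(); may return
        # stacked embedding weights (w) for the post-step hook, or None.
        def on_before_optimize(self, epoch: int) -> Optional[torch.Tensor]: ...

        # Receives the tensor returned above plus per-group learning rates.
        def on_after_optimize(
            self, w: Optional[torch.Tensor], lrs: dict[str, float]
        ) -> None: ...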
