diff options
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r-- | training/strategy/lora.py | 41 |
1 files changed, 29 insertions, 12 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index f942b76..14e3384 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -81,7 +81,7 @@ def lora_strategy_callbacks( | |||
81 | tokenizer.eval() | 81 | tokenizer.eval() |
82 | yield | 82 | yield |
83 | 83 | ||
84 | def on_before_optimize(epoch: int): | 84 | def on_before_optimize(cycle: int): |
85 | if not pti_mode: | 85 | if not pti_mode: |
86 | accelerator.clip_grad_norm_( | 86 | accelerator.clip_grad_norm_( |
87 | itertools.chain( | 87 | itertools.chain( |
@@ -89,7 +89,7 @@ def lora_strategy_callbacks( | |||
89 | text_encoder.text_model.encoder.parameters(), | 89 | text_encoder.text_model.encoder.parameters(), |
90 | text_encoder.text_model.final_layer_norm.parameters(), | 90 | text_encoder.text_model.final_layer_norm.parameters(), |
91 | ), | 91 | ), |
92 | max_grad_norm | 92 | max_grad_norm, |
93 | ) | 93 | ) |
94 | 94 | ||
95 | if len(placeholder_tokens) != 0 and use_emb_decay: | 95 | if len(placeholder_tokens) != 0 and use_emb_decay: |
@@ -108,7 +108,9 @@ def lora_strategy_callbacks( | |||
108 | 108 | ||
109 | if lambda_ != 0: | 109 | if lambda_ != 0: |
110 | norm = w[:, :].norm(dim=-1, keepdim=True) | 110 | norm = w[:, :].norm(dim=-1, keepdim=True) |
111 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | 111 | w[:].add_( |
112 | (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) | ||
113 | ) | ||
112 | 114 | ||
113 | @torch.no_grad() | 115 | @torch.no_grad() |
114 | def on_checkpoint(step, postfix): | 116 | def on_checkpoint(step, postfix): |
@@ -128,25 +130,32 @@ def lora_strategy_callbacks( | |||
128 | 130 | ||
129 | if not pti_mode: | 131 | if not pti_mode: |
130 | lora_config = {} | 132 | lora_config = {} |
131 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) | 133 | state_dict = get_peft_model_state_dict( |
134 | unet_, state_dict=accelerator.get_state_dict(unet_) | ||
135 | ) | ||
132 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) | 136 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) |
133 | 137 | ||
134 | text_encoder_state_dict = get_peft_model_state_dict( | 138 | text_encoder_state_dict = get_peft_model_state_dict( |
135 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) | 139 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) |
136 | ) | 140 | ) |
137 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} | 141 | text_encoder_state_dict = { |
142 | f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items() | ||
143 | } | ||
138 | state_dict.update(text_encoder_state_dict) | 144 | state_dict.update(text_encoder_state_dict) |
139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 145 | lora_config[ |
146 | "text_encoder_peft_config" | ||
147 | ] = text_encoder_.get_peft_config_as_dict(inference=True) | ||
140 | 148 | ||
141 | if len(placeholder_tokens) != 0: | 149 | if len(placeholder_tokens) != 0: |
142 | ti_state_dict = { | 150 | ti_state_dict = { |
143 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) | 151 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) |
144 | for (token, ids) | 152 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids) |
145 | in zip(placeholder_tokens, placeholder_token_ids) | ||
146 | } | 153 | } |
147 | state_dict.update(ti_state_dict) | 154 | state_dict.update(ti_state_dict) |
148 | 155 | ||
149 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") | 156 | save_file( |
157 | state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors" | ||
158 | ) | ||
150 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 159 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
151 | json.dump(lora_config, f) | 160 | json.dump(lora_config, f) |
152 | 161 | ||
@@ -185,10 +194,18 @@ def lora_prepare( | |||
185 | train_dataloader: DataLoader, | 194 | train_dataloader: DataLoader, |
186 | val_dataloader: Optional[DataLoader], | 195 | val_dataloader: Optional[DataLoader], |
187 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 196 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
188 | **kwargs | 197 | **kwargs, |
189 | ): | 198 | ): |
190 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 199 | ( |
191 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 200 | text_encoder, |
201 | unet, | ||
202 | optimizer, | ||
203 | train_dataloader, | ||
204 | val_dataloader, | ||
205 | lr_scheduler, | ||
206 | ) = accelerator.prepare( | ||
207 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
208 | ) | ||
192 | 209 | ||
193 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) | 210 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) |
194 | 211 | ||