author    Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
committer Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
commit    8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree      152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/strategy/lora.py
parent    Fix LoRA training with DAdan (diff)
Update
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--  training/strategy/lora.py | 41
1 file changed, 29 insertions, 12 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index f942b76..14e3384 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -81,7 +81,7 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
                 itertools.chain(
@@ -89,7 +89,7 @@ def lora_strategy_callbacks(
                     text_encoder.text_model.encoder.parameters(),
                     text_encoder.text_model.final_layer_norm.parameters(),
                 ),
-                max_grad_norm
+                max_grad_norm,
             )
 
         if len(placeholder_tokens) != 0 and use_emb_decay:
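Note on the clipping above: the UNet LoRA layers and the trainable text-encoder submodules are clipped as one group, so a single global-norm budget covers all of them. A minimal sketch of the same pattern in plain PyTorch, with two toy modules standing in for the real UNet and text encoder (names and sizes are illustrative, not from this repository):

import itertools

import torch
from torch import nn

# Stand-ins for the UNet LoRA layers and the trainable text-encoder parts.
unet_like = nn.Linear(16, 16)
encoder_like = nn.Linear(16, 16)

loss = unet_like(torch.randn(4, 16)).sum() + encoder_like(torch.randn(4, 16)).sum()
loss.backward()

# One global-norm clip over the chained parameter iterators, mirroring
# accelerator.clip_grad_norm_(itertools.chain(...), max_grad_norm).
max_grad_norm = 1.0
torch.nn.utils.clip_grad_norm_(
    itertools.chain(unet_like.parameters(), encoder_like.parameters()),
    max_grad_norm,
)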
@@ -108,7 +108,9 @@ def lora_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
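The reflowed update above is the embedding-norm decay applied to the trained placeholder-token rows: each row is moved along its own direction by lambda_ * (emb_decay_target - ||row||), so rows longer than the target shrink and shorter rows grow. A self-contained sketch with assumed shapes and constants (w, lambda_, and emb_decay_target below are illustrative values, not the trainer's state):

import torch

w = torch.randn(2, 768)   # stand-in for the placeholder-token embedding rows
emb_decay_target = 1.0    # assumed target norm
lambda_ = 0.01            # assumed decay strength

norm = w.norm(dim=-1, keepdim=True)  # per-row L2 norm
# Nudge each row toward the target norm along its own direction.
w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
print(w.norm(dim=-1))                # norms move toward emb_decay_target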
@@ -128,25 +130,32 @@ def lora_strategy_callbacks(
 
         if not pti_mode:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            state_dict = get_peft_model_state_dict(
+                unet_, state_dict=accelerator.get_state_dict(unet_)
+            )
             lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
             text_encoder_state_dict = get_peft_model_state_dict(
                 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
             )
-            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            text_encoder_state_dict = {
+                f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
+            }
             state_dict.update(text_encoder_state_dict)
-            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            lora_config[
+                "text_encoder_peft_config"
+            ] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         if len(placeholder_tokens) != 0:
             ti_state_dict = {
                 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
-                for (token, ids)
-                in zip(placeholder_tokens, placeholder_token_ids)
+                for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
             }
             state_dict.update(ti_state_dict)
 
-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+        save_file(
+            state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
+        )
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
 
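For reference, the checkpoint block above flattens three groups of tensors into one dict: UNet LoRA weights, text-encoder LoRA weights under a text_encoder_ prefix, and the textual-inversion embeddings under ti_ keys, then writes it with safetensors. A minimal save/load sketch with made-up keys and a temporary path (not the repository's real key names):

import tempfile
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file

# Illustrative stand-ins for the merged state dict.
state_dict = {
    "unet_lora.weight": torch.randn(4, 4),
    "text_encoder_lora.weight": torch.randn(4, 4),
    "ti_token": torch.randn(1, 768),
}

checkpoint_output_dir = Path(tempfile.mkdtemp())
path = checkpoint_output_dir / "100_end.safetensors"
save_file(state_dict, path)   # safetensors stores a flat mapping of name -> tensor
restored = load_file(path)
assert set(restored) == set(state_dict)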
@@ -185,10 +194,18 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    **kwargs
+    **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
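The reflow in the last hunk only makes the unpacking of accelerator.prepare explicit: prepare wraps each object for the current device and distributed setup and returns them in the order they were passed. A small usage sketch with toy objects (the model, optimizer, and scheduler here are placeholders, not the trainer's UNet or text encoder):

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
train_dataloader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# prepare() returns wrapped versions in the same order, so they can be
# unpacked back into the original names, as the reformatted call does.
(
    model,
    optimizer,
    train_dataloader,
    lr_scheduler,
) = accelerator.prepare(model, optimizer, train_dataloader, lr_scheduler)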