author      Volpeon <git@volpeon.ink>   2023-06-21 13:28:49 +0200
committer   Volpeon <git@volpeon.ink>   2023-06-21 13:28:49 +0200
commit      8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree        152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/strategy
parent      Fix LoRA training with DAdan (diff)
Update
Diffstat (limited to 'training/strategy')
-rw-r--r--   training/strategy/dreambooth.py   29
-rw-r--r--   training/strategy/lora.py         41
-rw-r--r--   training/strategy/ti.py           27
3 files changed, 65 insertions, 32 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e6fcc89..88b441b 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    train_text_encoder_epochs: int,
+    train_text_encoder_cycles: int,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks(
             return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
         tokenizer.train()
 
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             text_encoder.train()
-        elif epoch == train_text_encoder_epochs:
-            text_encoder.requires_grad_(False)
-            text_encoder.eval()
+            tokenizer.train()
 
         yield
 
@@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         params_to_clip = [unet.parameters()]
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
@@ -189,8 +187,16 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     text_encoder.text_model.embeddings.requires_grad_(False)
 
@@ -198,6 +204,5 @@ def dreambooth_prepare(
 
 
 dreambooth_strategy = TrainingStrategy(
-    callbacks=dreambooth_strategy_callbacks,
-    prepare=dreambooth_prepare
+    callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare
 )
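The on_before_optimize hunk above only renames the gating argument from epoch to cycle; the clipping pattern itself is unchanged: parameters from several modules are chained into a single iterator and clipped to one global norm, with the text encoder included only while it is still being trained. A minimal, self-contained sketch of that pattern, using plain torch.nn.utils.clip_grad_norm_ in place of the accelerator-wrapped call and stand-in modules (every name below is illustrative, not the trainer's state):

import itertools

import torch

unet = torch.nn.Linear(8, 8)          # stand-in for the real UNet
text_encoder = torch.nn.Linear(8, 8)  # stand-in for the real text encoder

def clip_grads(cycle: int, train_text_encoder_cycles: int, max_grad_norm: float = 1.0):
    params_to_clip = [unet.parameters()]
    if cycle < train_text_encoder_cycles:
        # the text encoder only takes part in training (and clipping)
        # during the first train_text_encoder_cycles cycles
        params_to_clip.append(text_encoder.parameters())
    torch.nn.utils.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)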
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index f942b76..14e3384 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -81,7 +81,7 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
                 itertools.chain(
@@ -89,7 +89,7 @@ def lora_strategy_callbacks(
                     text_encoder.text_model.encoder.parameters(),
                     text_encoder.text_model.final_layer_norm.parameters(),
                 ),
-                max_grad_norm
+                max_grad_norm,
             )
 
         if len(placeholder_tokens) != 0 and use_emb_decay:
@@ -108,7 +108,9 @@ def lora_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
@@ -128,25 +130,32 @@ def lora_strategy_callbacks(
 
         if not pti_mode:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            state_dict = get_peft_model_state_dict(
+                unet_, state_dict=accelerator.get_state_dict(unet_)
+            )
             lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
             text_encoder_state_dict = get_peft_model_state_dict(
                 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
             )
-            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            text_encoder_state_dict = {
+                f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
+            }
             state_dict.update(text_encoder_state_dict)
-            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            lora_config[
+                "text_encoder_peft_config"
+            ] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         if len(placeholder_tokens) != 0:
             ti_state_dict = {
                 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
-                for (token, ids)
-                in zip(placeholder_tokens, placeholder_token_ids)
+                for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
             }
             state_dict.update(ti_state_dict)
 
-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+        save_file(
+            state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
+        )
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
 
@@ -185,10 +194,18 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    **kwargs
+    **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
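Both lora.py and ti.py reformat the same embedding-decay expression here without changing what it computes: each embedding row is nudged along its own direction so that its L2 norm drifts toward emb_decay_target, with step size lambda_ (emb_decay times the current learning rate in the surrounding callbacks). A self-contained sketch with illustrative values (the tensor and constants below are not the trainer's state):

import torch

w = torch.randn(4, 768)   # stand-in for the trained embedding rows
emb_decay_target = 0.4
lambda_ = 0.01            # emb_decay * learning rate in the real callbacks

norm = w.norm(dim=-1, keepdim=True)  # per-row L2 norm
w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
# rows with norm above the target shrink, rows below it grow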
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bc1d7d..7373982 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if use_emb_decay:
             params = [
                 p
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
+            ema_embeddings.step(
+                text_encoder.text_model.embeddings.token_embedding.parameters()
+            )
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     def on_log():
         if ema_embeddings is not None:
@@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks(
         print(f"Saving checkpoint for step {step}...")
 
         with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
                 )
 
     @torch.no_grad()
@@ -183,7 +187,7 @@ def textual_inversion_prepare(
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     gradient_checkpointing: bool = False,
-    **kwargs
+    **kwargs,
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -191,8 +195,15 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     unet.to(accelerator.device, dtype=weight_dtype)
     unet.requires_grad_(False)
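The textual_inversion_prepare hunk keeps the dtype handling intact: the frozen UNet is cast to the accelerator's mixed-precision dtype and has its gradients disabled, while the text encoder goes through accelerator.prepare. A small sketch of just the dtype mapping, written as a hypothetical helper rather than code from the repository:

import torch

def frozen_weight_dtype(mixed_precision: str) -> torch.dtype:
    # mirrors the fp16/bf16 branches shown above; anything else falls
    # back to full precision
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32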