Diffstat (limited to 'training/strategy')
-rw-r--r--   training/strategy/dreambooth.py   29
-rw-r--r--   training/strategy/lora.py         41
-rw-r--r--   training/strategy/ti.py           27
3 files changed, 65 insertions, 32 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e6fcc89..88b441b 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    train_text_encoder_epochs: int,
+    train_text_encoder_cycles: int,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks(
         return nullcontext()

     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
         tokenizer.train()

-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             text_encoder.train()
-        elif epoch == train_text_encoder_epochs:
-            text_encoder.requires_grad_(False)
-            text_encoder.eval()
+            tokenizer.train()

         yield

@@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield

-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         params_to_clip = [unet.parameters()]
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

@@ -189,8 +187,16 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )

     text_encoder.text_model.embeddings.requires_grad_(False)

@@ -198,6 +204,5 @@ def dreambooth_prepare(


 dreambooth_strategy = TrainingStrategy(
-    callbacks=dreambooth_strategy_callbacks,
-    prepare=dreambooth_prepare
+    callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare
 )
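Note on the dreambooth change above: `train_text_encoder_epochs` is renamed to `train_text_encoder_cycles`, and `on_train`/`on_before_optimize` now receive a `cycle` index; the text encoder is simply no longer trained (and no longer gradient-clipped) once the cycle index reaches the limit. A minimal sketch of that gate, using hypothetical stand-in modules rather than the real U-Net and text encoder:

    import itertools

    import torch

    # Hypothetical stand-ins, only to illustrate the cycle gate.
    unet = torch.nn.Linear(4, 4)
    text_encoder = torch.nn.Linear(4, 4)
    train_text_encoder_cycles = 2  # assumed example value

    def params_to_clip(cycle: int) -> list:
        # Mirrors on_before_optimize: include the text encoder's parameters
        # only while cycle < train_text_encoder_cycles.
        groups = [unet.parameters()]
        if cycle < train_text_encoder_cycles:
            groups.append(text_encoder.parameters())
        return list(itertools.chain(*groups))

    print(len(params_to_clip(0)))  # 4 tensors: unet + text encoder
    print(len(params_to_clip(3)))  # 2 tensors: unet only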
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index f942b76..14e3384 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -81,7 +81,7 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield

-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
                 itertools.chain(
@@ -89,7 +89,7 @@ def lora_strategy_callbacks(
                     text_encoder.text_model.encoder.parameters(),
                     text_encoder.text_model.final_layer_norm.parameters(),
                 ),
-                max_grad_norm
+                max_grad_norm,
             )

         if len(placeholder_tokens) != 0 and use_emb_decay:
@@ -108,7 +108,9 @@ def lora_strategy_callbacks(

         if lambda_ != 0:
             norm = w[:, :].norm(dim=-1, keepdim=True)
-            w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            w[:].add_(
+                (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+            )

     @torch.no_grad()
     def on_checkpoint(step, postfix):
@@ -128,25 +130,32 @@ def lora_strategy_callbacks(

         if not pti_mode:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            state_dict = get_peft_model_state_dict(
+                unet_, state_dict=accelerator.get_state_dict(unet_)
+            )
             lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)

             text_encoder_state_dict = get_peft_model_state_dict(
                 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
             )
-            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            text_encoder_state_dict = {
+                f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
+            }
             state_dict.update(text_encoder_state_dict)
-            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            lora_config[
+                "text_encoder_peft_config"
+            ] = text_encoder_.get_peft_config_as_dict(inference=True)

         if len(placeholder_tokens) != 0:
             ti_state_dict = {
                 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
-                for (token, ids)
-                in zip(placeholder_tokens, placeholder_token_ids)
+                for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
             }
             state_dict.update(ti_state_dict)

-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+        save_file(
+            state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
+        )
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)

@@ -185,10 +194,18 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    **kwargs
+    **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )

     # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

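Both lora.py and ti.py contain the same embedding-norm decay update that this diff only reflows: each embedding row is stepped along its own direction so its L2 norm moves a fraction `lambda_` of the way toward `emb_decay_target`. A self-contained illustration with made-up values (the real `lambda_` and `emb_decay_target` come from the strategy configuration, not from this sketch):

    import torch

    emb_decay_target = 0.4  # assumed example value
    lambda_ = 0.1           # assumed example value

    w = torch.randn(3, 768)  # stand-in for the trainable placeholder embeddings
    norm = w[:, :].norm(dim=-1, keepdim=True)
    print(norm.squeeze())    # row norms before the update (around sqrt(768))

    # Same update as in the diff: new norm = (1 - lambda_) * old + lambda_ * target.
    w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
    print(w.norm(dim=-1))    # row norms after: a step closer to emb_decay_target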
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bc1d7d..7373982 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield

     @torch.no_grad()
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if use_emb_decay:
             params = [
                 p
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
+            ema_embeddings.step(
+                text_encoder.text_model.embeddings.token_embedding.parameters()
+            )

         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks(

         if lambda_ != 0:
             norm = w[:, :].norm(dim=-1, keepdim=True)
-            w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            w[:].add_(
+                (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+            )

     def on_log():
         if ema_embeddings is not None:
@@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks(
         print(f"Saving checkpoint for step {step}...")

         with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
                 )

     @torch.no_grad()
@@ -183,7 +187,7 @@ def textual_inversion_prepare(
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     gradient_checkpointing: bool = False,
-    **kwargs
+    **kwargs,
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -191,8 +195,15 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )

     unet.to(accelerator.device, dtype=weight_dtype)
     unet.requires_grad_(False)
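In all three `*_prepare` functions the `accelerator.prepare(...)` call is only reflowed into a parenthesized multi-line unpack; Accelerate returns the prepared objects in the order they were passed in, so behaviour is unchanged. A minimal sketch with toy objects (not the trainer's real models or dataloaders):

    import torch
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, TensorDataset

    accelerator = Accelerator()

    # Toy stand-ins for text_encoder / optimizer / dataloader / lr_scheduler.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    train_dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # Same order in, same order out; the wrapping is purely cosmetic.
    (
        model,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(model, optimizer, train_dataloader, lr_scheduler)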
