Diffstat (limited to 'training')
-rw-r--r--  training/strategy/lora.py | 35
1 file changed, 19 insertions(+), 16 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 80ffa9c..912ff26 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -34,6 +34,7 @@ def lora_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
+    pti_mode: bool = False,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -79,10 +80,11 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        accelerator.clip_grad_norm_(
-            itertools.chain(unet.parameters(), text_encoder.parameters()),
-            max_grad_norm
-        )
+        if not pti_mode:
+            accelerator.clip_grad_norm_(
+                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                max_grad_norm
+            )
 
         if use_emb_decay:
             params = [
@@ -117,20 +119,21 @@ def lora_strategy_callbacks(
                 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
             )
 
-        lora_config = {}
-        state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
-        lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
+        if not pti_mode:
+            lora_config = {}
+            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
-        text_encoder_state_dict = get_peft_model_state_dict(
-            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
-        )
-        text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
-        state_dict.update(text_encoder_state_dict)
-        lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            text_encoder_state_dict = get_peft_model_state_dict(
+                text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
+            )
+            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            state_dict.update(text_encoder_state_dict)
+            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
-        with open(checkpoint_output_dir / "lora_config.json", "w") as f:
-            json.dump(lora_config, f)
+            save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+            with open(checkpoint_output_dir / "lora_config.json", "w") as f:
+                json.dump(lora_config, f)
 
         del unet_
         del text_encoder_
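
For reference, a minimal sketch of reading back a checkpoint written by on_checkpoint above (the part now skipped when pti_mode=True). This is an illustration, not code from the commit; the directory, step, and postfix values are placeholders, and the split logic only mirrors the key layout the diff writes (text-encoder keys prefixed with "text_encoder_", both PEFT configs in lora_config.json).

import json
from pathlib import Path

from safetensors.torch import load_file

checkpoint_output_dir = Path("checkpoints")  # placeholder path
step, postfix = 1000, "end"                  # placeholder values

# Combined unet + text-encoder LoRA weights saved by on_checkpoint.
state_dict = load_file(str(checkpoint_output_dir / f"{step}_{postfix}.safetensors"))
with open(checkpoint_output_dir / "lora_config.json") as f:
    lora_config = json.load(f)

# Split the merged state dict back into unet / text-encoder parts, undoing the
# "text_encoder_" prefix applied before state_dict.update(...) above.
text_encoder_state_dict = {
    k.removeprefix("text_encoder_"): v
    for k, v in state_dict.items()
    if k.startswith("text_encoder_")
}
unet_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoder_")}

unet_peft_config = lora_config["peft_config"]
text_encoder_peft_config = lora_config["text_encoder_peft_config"]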