diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/lora.py | 35 |
1 files changed, 19 insertions, 16 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 80ffa9c..912ff26 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -34,6 +34,7 @@ def lora_strategy_callbacks( | |||
34 | seed: int, | 34 | seed: int, |
35 | placeholder_tokens: list[str], | 35 | placeholder_tokens: list[str], |
36 | placeholder_token_ids: list[list[int]], | 36 | placeholder_token_ids: list[list[int]], |
37 | pti_mode: bool = False, | ||
37 | use_emb_decay: bool = False, | 38 | use_emb_decay: bool = False, |
38 | emb_decay_target: float = 0.4, | 39 | emb_decay_target: float = 0.4, |
39 | emb_decay: float = 1e-2, | 40 | emb_decay: float = 1e-2, |
@@ -79,10 +80,11 @@ def lora_strategy_callbacks( | |||
79 | yield | 80 | yield |
80 | 81 | ||
81 | def on_before_optimize(lr: float, epoch: int): | 82 | def on_before_optimize(lr: float, epoch: int): |
82 | accelerator.clip_grad_norm_( | 83 | if not pti_mode: |
83 | itertools.chain(unet.parameters(), text_encoder.parameters()), | 84 | accelerator.clip_grad_norm_( |
84 | max_grad_norm | 85 | itertools.chain(unet.parameters(), text_encoder.parameters()), |
85 | ) | 86 | max_grad_norm |
87 | ) | ||
86 | 88 | ||
87 | if use_emb_decay: | 89 | if use_emb_decay: |
88 | params = [ | 90 | params = [ |
@@ -117,20 +119,21 @@ def lora_strategy_callbacks( | |||
117 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 119 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" |
118 | ) | 120 | ) |
119 | 121 | ||
120 | lora_config = {} | 122 | if not pti_mode: |
121 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) | 123 | lora_config = {} |
122 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) | 124 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) |
125 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) | ||
123 | 126 | ||
124 | text_encoder_state_dict = get_peft_model_state_dict( | 127 | text_encoder_state_dict = get_peft_model_state_dict( |
125 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) | 128 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) |
126 | ) | 129 | ) |
127 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} | 130 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} |
128 | state_dict.update(text_encoder_state_dict) | 131 | state_dict.update(text_encoder_state_dict) |
129 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 132 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) |
130 | 133 | ||
131 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") | 134 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") |
132 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 135 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
133 | json.dump(lora_config, f) | 136 | json.dump(lora_config, f) |
134 | 137 | ||
135 | del unet_ | 138 | del unet_ |
136 | del text_encoder_ | 139 | del text_encoder_ |