diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 14 |
1 files changed, 3 insertions, 11 deletions
diff --git a/train_lora.py b/train_lora.py index 737af58..dea58cf 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -15,7 +15,7 @@ import hidet | |||
| 15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
| 16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from peft import LoraConfig, LoraModel | 18 | from peft import LoraConfig, get_peft_model |
| 19 | # from diffusers.models.attention_processor import AttnProcessor | 19 | # from diffusers.models.attention_processor import AttnProcessor |
| 20 | import transformers | 20 | import transformers |
| 21 | 21 | ||
| @@ -731,7 +731,7 @@ def main(): | |||
| 731 | lora_dropout=args.lora_dropout, | 731 | lora_dropout=args.lora_dropout, |
| 732 | bias=args.lora_bias, | 732 | bias=args.lora_bias, |
| 733 | ) | 733 | ) |
| 734 | unet = LoraModel(unet_config, unet) | 734 | unet = get_peft_model(unet, unet_config) |
| 735 | 735 | ||
| 736 | text_encoder_config = LoraConfig( | 736 | text_encoder_config = LoraConfig( |
| 737 | r=args.lora_text_encoder_r, | 737 | r=args.lora_text_encoder_r, |
| @@ -740,7 +740,7 @@ def main(): | |||
| 740 | lora_dropout=args.lora_text_encoder_dropout, | 740 | lora_dropout=args.lora_text_encoder_dropout, |
| 741 | bias=args.lora_text_encoder_bias, | 741 | bias=args.lora_text_encoder_bias, |
| 742 | ) | 742 | ) |
| 743 | text_encoder = LoraModel(text_encoder_config, text_encoder) | 743 | text_encoder = get_peft_model(text_encoder, text_encoder_config) |
| 744 | 744 | ||
| 745 | vae.enable_slicing() | 745 | vae.enable_slicing() |
| 746 | 746 | ||
| @@ -1167,14 +1167,6 @@ def main(): | |||
| 1167 | group_labels.append("unet") | 1167 | group_labels.append("unet") |
| 1168 | 1168 | ||
| 1169 | if training_iter < args.train_text_encoder_cycles: | 1169 | if training_iter < args.train_text_encoder_cycles: |
| 1170 | # if len(placeholder_tokens) != 0: | ||
| 1171 | # params_to_optimize.append({ | ||
| 1172 | # "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
| 1173 | # "lr": learning_rate_emb, | ||
| 1174 | # "weight_decay": 0, | ||
| 1175 | # }) | ||
| 1176 | # group_labels.append("emb") | ||
| 1177 | |||
| 1178 | params_to_optimize.append({ | 1170 | params_to_optimize.append({ |
| 1179 | "params": ( | 1171 | "params": ( |
| 1180 | param | 1172 | param |
