diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 14 |
1 files changed, 3 insertions, 11 deletions
diff --git a/train_lora.py b/train_lora.py index 737af58..dea58cf 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -15,7 +15,7 @@ import hidet | |||
15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from peft import LoraConfig, LoraModel | 18 | from peft import LoraConfig, get_peft_model |
19 | # from diffusers.models.attention_processor import AttnProcessor | 19 | # from diffusers.models.attention_processor import AttnProcessor |
20 | import transformers | 20 | import transformers |
21 | 21 | ||
@@ -731,7 +731,7 @@ def main(): | |||
731 | lora_dropout=args.lora_dropout, | 731 | lora_dropout=args.lora_dropout, |
732 | bias=args.lora_bias, | 732 | bias=args.lora_bias, |
733 | ) | 733 | ) |
734 | unet = LoraModel(unet_config, unet) | 734 | unet = get_peft_model(unet, unet_config) |
735 | 735 | ||
736 | text_encoder_config = LoraConfig( | 736 | text_encoder_config = LoraConfig( |
737 | r=args.lora_text_encoder_r, | 737 | r=args.lora_text_encoder_r, |
@@ -740,7 +740,7 @@ def main(): | |||
740 | lora_dropout=args.lora_text_encoder_dropout, | 740 | lora_dropout=args.lora_text_encoder_dropout, |
741 | bias=args.lora_text_encoder_bias, | 741 | bias=args.lora_text_encoder_bias, |
742 | ) | 742 | ) |
743 | text_encoder = LoraModel(text_encoder_config, text_encoder) | 743 | text_encoder = get_peft_model(text_encoder, text_encoder_config) |
744 | 744 | ||
745 | vae.enable_slicing() | 745 | vae.enable_slicing() |
746 | 746 | ||
@@ -1167,14 +1167,6 @@ def main(): | |||
1167 | group_labels.append("unet") | 1167 | group_labels.append("unet") |
1168 | 1168 | ||
1169 | if training_iter < args.train_text_encoder_cycles: | 1169 | if training_iter < args.train_text_encoder_cycles: |
1170 | # if len(placeholder_tokens) != 0: | ||
1171 | # params_to_optimize.append({ | ||
1172 | # "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
1173 | # "lr": learning_rate_emb, | ||
1174 | # "weight_decay": 0, | ||
1175 | # }) | ||
1176 | # group_labels.append("emb") | ||
1177 | |||
1178 | params_to_optimize.append({ | 1170 | params_to_optimize.append({ |
1179 | "params": ( | 1171 | "params": ( |
1180 | param | 1172 | param |