Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py | 14
1 file changed, 3 insertions(+), 11 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 737af58..dea58cf 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -15,7 +15,7 @@ import hidet
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from peft import LoraConfig, LoraModel
+from peft import LoraConfig, get_peft_model
 # from diffusers.models.attention_processor import AttnProcessor
 import transformers
 
@@ -731,7 +731,7 @@ def main():
         lora_dropout=args.lora_dropout,
         bias=args.lora_bias,
     )
-    unet = LoraModel(unet_config, unet)
+    unet = get_peft_model(unet, unet_config)
 
     text_encoder_config = LoraConfig(
         r=args.lora_text_encoder_r,
@@ -740,7 +740,7 @@ def main():
         lora_dropout=args.lora_text_encoder_dropout,
         bias=args.lora_text_encoder_bias,
     )
-    text_encoder = LoraModel(text_encoder_config, text_encoder)
+    text_encoder = get_peft_model(text_encoder, text_encoder_config)
 
     vae.enable_slicing()
 
@@ -1167,14 +1167,6 @@ def main():
         group_labels.append("unet")
 
         if training_iter < args.train_text_encoder_cycles:
-            # if len(placeholder_tokens) != 0:
-            #     params_to_optimize.append({
-            #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
-            #         "lr": learning_rate_emb,
-            #         "weight_decay": 0,
-            #     })
-            #     group_labels.append("emb")
-
             params_to_optimize.append({
                 "params": (
                     param
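For context, a minimal, self-contained sketch of the peft pattern this diff migrates to: get_peft_model() wraps a module according to a LoraConfig so that only the injected LoRA matrices are trainable. This is not the repository's actual setup; the TinyBlock module, the "proj" target name, and all hyperparameter values below are illustrative assumptions.

# Illustrative sketch of peft's public API; module and values are made up for the example.
import torch
from peft import LoraConfig, get_peft_model

class TinyBlock(torch.nn.Module):
    """Stand-in for a real network such as the UNet or text encoder."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(32, 32)  # layer that will receive LoRA adapters

    def forward(self, x):
        return self.proj(x)

config = LoraConfig(
    r=4,                      # adapter rank
    lora_alpha=8,             # scaling factor
    target_modules=["proj"],  # module names to wrap with LoRA adapters
    lora_dropout=0.0,
    bias="none",
)

# Old call in this file: LoraModel(config, model); new call: get_peft_model(model, config).
model = get_peft_model(TinyBlock(), config)
model.print_trainable_parameters()  # only the LoRA A/B matrices require grad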