summary refs log tree commit diff stats
path: root/train_lora.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_lora.py')
-rw-r--r-- train_lora.py 13
1 files changed, 11 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py
index c197206..9cf17c7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
23from training.functional import train, add_placeholder_tokens, get_models 23from training.functional import train, add_placeholder_tokens, get_models
24from training.strategy.lora import lora_strategy 24from training.strategy.lora import lora_strategy
25from training.optimization import get_scheduler 25from training.optimization import get_scheduler
26from training.util import save_args 26from training.util import AverageMeter, save_args
27 27
28# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py 28# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
29UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] 29UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
@@ -1035,6 +1035,11 @@ def main():
1035 lr_warmup_epochs = args.lr_warmup_epochs 1035 lr_warmup_epochs = args.lr_warmup_epochs
1036 lr_cycles = args.lr_cycles 1036 lr_cycles = args.lr_cycles
1037 1037
1038 avg_loss = AverageMeter()
1039 avg_acc = AverageMeter()
1040 avg_loss_val = AverageMeter()
1041 avg_acc_val = AverageMeter()
1042
1038 while True: 1043 while True:
1039 if len(auto_cycles) != 0: 1044 if len(auto_cycles) != 0:
1040 response = auto_cycles.pop(0) 1045 response = auto_cycles.pop(0)
@@ -1122,7 +1127,7 @@ def main():
1122 warmup_epochs=lr_warmup_epochs, 1127 warmup_epochs=lr_warmup_epochs,
1123 ) 1128 )
1124 1129
1125 lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}" 1130 lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}"
1126 1131
1127 trainer( 1132 trainer(
1128 strategy=lora_strategy, 1133 strategy=lora_strategy,
@@ -1142,6 +1147,10 @@ def main():
1142 sample_frequency=lora_sample_frequency, 1147 sample_frequency=lora_sample_frequency,
1143 offset_noise_strength=args.offset_noise_strength, 1148 offset_noise_strength=args.offset_noise_strength,
1144 no_val=args.valid_set_size == 0, 1149 no_val=args.valid_set_size == 0,
1150 avg_loss=avg_loss,
1151 avg_acc=avg_acc,
1152 avg_loss_val=avg_loss_val,
1153 avg_acc_val=avg_acc_val,
1145 ) 1154 )
1146 1155
1147 training_iter += 1 1156 training_iter += 1