diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 13 |
1 file changed, 11 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py
index c197206..9cf17c7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter | |||
23 | from training.functional import train, add_placeholder_tokens, get_models | 23 | from training.functional import train, add_placeholder_tokens, get_models |
24 | from training.strategy.lora import lora_strategy | 24 | from training.strategy.lora import lora_strategy |
25 | from training.optimization import get_scheduler | 25 | from training.optimization import get_scheduler |
26 | from training.util import save_args | 26 | from training.util import AverageMeter, save_args |
27 | 27 | ||
28 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 28 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
29 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | 29 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] |
@@ -1035,6 +1035,11 @@ def main(): | |||
1035 | lr_warmup_epochs = args.lr_warmup_epochs | 1035 | lr_warmup_epochs = args.lr_warmup_epochs |
1036 | lr_cycles = args.lr_cycles | 1036 | lr_cycles = args.lr_cycles |
1037 | 1037 | ||
1038 | avg_loss = AverageMeter() | ||
1039 | avg_acc = AverageMeter() | ||
1040 | avg_loss_val = AverageMeter() | ||
1041 | avg_acc_val = AverageMeter() | ||
1042 | |||
1038 | while True: | 1043 | while True: |
1039 | if len(auto_cycles) != 0: | 1044 | if len(auto_cycles) != 0: |
1040 | response = auto_cycles.pop(0) | 1045 | response = auto_cycles.pop(0) |
@@ -1122,7 +1127,7 @@ def main(): | |||
1122 | warmup_epochs=lr_warmup_epochs, | 1127 | warmup_epochs=lr_warmup_epochs, |
1123 | ) | 1128 | ) |
1124 | 1129 | ||
1125 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}" | 1130 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" |
1126 | 1131 | ||
1127 | trainer( | 1132 | trainer( |
1128 | strategy=lora_strategy, | 1133 | strategy=lora_strategy, |
@@ -1142,6 +1147,10 @@ def main(): | |||
1142 | sample_frequency=lora_sample_frequency, | 1147 | sample_frequency=lora_sample_frequency, |
1143 | offset_noise_strength=args.offset_noise_strength, | 1148 | offset_noise_strength=args.offset_noise_strength, |
1144 | no_val=args.valid_set_size == 0, | 1149 | no_val=args.valid_set_size == 0, |
1150 | avg_loss=avg_loss, | ||
1151 | avg_acc=avg_acc, | ||
1152 | avg_loss_val=avg_loss_val, | ||
1153 | avg_acc_val=avg_acc_val, | ||
1145 | ) | 1154 | ) |
1146 | 1155 | ||
1147 | training_iter += 1 | 1156 | training_iter += 1 |