diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py index c197206..9cf17c7 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter | |||
| 23 | from training.functional import train, add_placeholder_tokens, get_models | 23 | from training.functional import train, add_placeholder_tokens, get_models |
| 24 | from training.strategy.lora import lora_strategy | 24 | from training.strategy.lora import lora_strategy |
| 25 | from training.optimization import get_scheduler | 25 | from training.optimization import get_scheduler |
| 26 | from training.util import save_args | 26 | from training.util import AverageMeter, save_args |
| 27 | 27 | ||
| 28 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 28 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
| 29 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | 29 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] |
| @@ -1035,6 +1035,11 @@ def main(): | |||
| 1035 | lr_warmup_epochs = args.lr_warmup_epochs | 1035 | lr_warmup_epochs = args.lr_warmup_epochs |
| 1036 | lr_cycles = args.lr_cycles | 1036 | lr_cycles = args.lr_cycles |
| 1037 | 1037 | ||
| 1038 | avg_loss = AverageMeter() | ||
| 1039 | avg_acc = AverageMeter() | ||
| 1040 | avg_loss_val = AverageMeter() | ||
| 1041 | avg_acc_val = AverageMeter() | ||
| 1042 | |||
| 1038 | while True: | 1043 | while True: |
| 1039 | if len(auto_cycles) != 0: | 1044 | if len(auto_cycles) != 0: |
| 1040 | response = auto_cycles.pop(0) | 1045 | response = auto_cycles.pop(0) |
| @@ -1122,7 +1127,7 @@ def main(): | |||
| 1122 | warmup_epochs=lr_warmup_epochs, | 1127 | warmup_epochs=lr_warmup_epochs, |
| 1123 | ) | 1128 | ) |
| 1124 | 1129 | ||
| 1125 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}" | 1130 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" |
| 1126 | 1131 | ||
| 1127 | trainer( | 1132 | trainer( |
| 1128 | strategy=lora_strategy, | 1133 | strategy=lora_strategy, |
| @@ -1142,6 +1147,10 @@ def main(): | |||
| 1142 | sample_frequency=lora_sample_frequency, | 1147 | sample_frequency=lora_sample_frequency, |
| 1143 | offset_noise_strength=args.offset_noise_strength, | 1148 | offset_noise_strength=args.offset_noise_strength, |
| 1144 | no_val=args.valid_set_size == 0, | 1149 | no_val=args.valid_set_size == 0, |
| 1150 | avg_loss=avg_loss, | ||
| 1151 | avg_acc=avg_acc, | ||
| 1152 | avg_loss_val=avg_loss_val, | ||
| 1153 | avg_acc_val=avg_acc_val, | ||
| 1145 | ) | 1154 | ) |
| 1146 | 1155 | ||
| 1147 | training_iter += 1 | 1156 | training_iter += 1 |
