diff options
author | Volpeon <git@volpeon.ink> | 2023-05-16 16:48:51 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-05-16 16:48:51 +0200 |
commit | 55a12f2c683b2ecfa4fc8b4015462ad2798abda5 (patch) | |
tree | feeb3f9a041466e773bb5921cbf0adb208d60a49 /train_lora.py | |
parent | Avoid model recompilation due to varying prompt lengths (diff) | |
download | textual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.tar.gz textual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.tar.bz2 textual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.zip |
Fix LoRA training with DAdan
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 58 |
1 file changed, 27 insertions, 31 deletions
diff --git a/train_lora.py b/train_lora.py index 12d7e72..c74dd8f 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -48,8 +48,8 @@ warnings.filterwarnings('ignore') | |||
48 | torch.backends.cuda.matmul.allow_tf32 = True | 48 | torch.backends.cuda.matmul.allow_tf32 = True |
49 | torch.backends.cudnn.benchmark = True | 49 | torch.backends.cudnn.benchmark = True |
50 | 50 | ||
51 | torch._dynamo.config.log_level = logging.WARNING | 51 | # torch._dynamo.config.log_level = logging.WARNING |
52 | # torch._dynamo.config.suppress_errors = True | 52 | torch._dynamo.config.suppress_errors = True |
53 | 53 | ||
54 | hidet.torch.dynamo_config.use_tensor_core(True) | 54 | hidet.torch.dynamo_config.use_tensor_core(True) |
55 | hidet.torch.dynamo_config.search_space(0) | 55 | hidet.torch.dynamo_config.search_space(0) |
@@ -1143,6 +1143,28 @@ def main(): | |||
1143 | avg_loss_val = AverageMeter() | 1143 | avg_loss_val = AverageMeter() |
1144 | avg_acc_val = AverageMeter() | 1144 | avg_acc_val = AverageMeter() |
1145 | 1145 | ||
1146 | params_to_optimize = [ | ||
1147 | { | ||
1148 | "params": ( | ||
1149 | param | ||
1150 | for param in unet.parameters() | ||
1151 | if param.requires_grad | ||
1152 | ), | ||
1153 | "lr": learning_rate_unet, | ||
1154 | }, | ||
1155 | { | ||
1156 | "params": ( | ||
1157 | param | ||
1158 | for param in text_encoder.parameters() | ||
1159 | if param.requires_grad | ||
1160 | ), | ||
1161 | "lr": learning_rate_text, | ||
1162 | } | ||
1163 | ] | ||
1164 | group_labels = ["unet", "text"] | ||
1165 | |||
1166 | lora_optimizer = create_optimizer(params_to_optimize) | ||
1167 | |||
1146 | while True: | 1168 | while True: |
1147 | if len(auto_cycles) != 0: | 1169 | if len(auto_cycles) != 0: |
1148 | response = auto_cycles.pop(0) | 1170 | response = auto_cycles.pop(0) |
@@ -1182,35 +1204,9 @@ def main(): | |||
1182 | print("") | 1204 | print("") |
1183 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 1205 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") |
1184 | print("") | 1206 | print("") |
1185 | 1207 | ||
1186 | params_to_optimize = [] | 1208 | for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): |
1187 | group_labels = [] | 1209 | group['lr'] = lr |
1188 | |||
1189 | params_to_optimize.append({ | ||
1190 | "params": ( | ||
1191 | param | ||
1192 | for param in unet.parameters() | ||
1193 | if param.requires_grad | ||
1194 | ), | ||
1195 | "lr": learning_rate_unet, | ||
1196 | }) | ||
1197 | group_labels.append("unet") | ||
1198 | |||
1199 | if training_iter < args.train_text_encoder_cycles: | ||
1200 | params_to_optimize.append({ | ||
1201 | "params": ( | ||
1202 | param | ||
1203 | for param in itertools.chain( | ||
1204 | text_encoder.text_model.encoder.parameters(), | ||
1205 | text_encoder.text_model.final_layer_norm.parameters(), | ||
1206 | ) | ||
1207 | if param.requires_grad | ||
1208 | ), | ||
1209 | "lr": learning_rate_text, | ||
1210 | }) | ||
1211 | group_labels.append("text") | ||
1212 | |||
1213 | lora_optimizer = create_optimizer(params_to_optimize) | ||
1214 | 1210 | ||
1215 | lora_lr_scheduler = create_lr_scheduler( | 1211 | lora_lr_scheduler = create_lr_scheduler( |
1216 | lr_scheduler, | 1212 | lr_scheduler, |