path: root/train_lora.py
author     Volpeon <git@volpeon.ink>    2023-05-16 16:48:51 +0200
committer  Volpeon <git@volpeon.ink>    2023-05-16 16:48:51 +0200
commit     55a12f2c683b2ecfa4fc8b4015462ad2798abda5 (patch)
tree       feeb3f9a041466e773bb5921cbf0adb208d60a49 /train_lora.py
parent     Avoid model recompilation due to varying prompt lengths (diff)
download   textual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.tar.gz
           textual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.tar.bz2
           textual-inversion-diff-55a12f2c683b2ecfa4fc8b4015462ad2798abda5.zip
Fix LoRA training with DAdan
Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  58
1 file changed, 27 insertions, 31 deletions
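
The fix itself sits in the two larger hunks below: params_to_optimize, group_labels, and the optimizer are now built once before the cycle loop, and each cycle only rewrites the learning rates of the existing parameter groups instead of calling create_optimizer() again. A d-adaptation optimizer (the "DAdan" of the commit title) keeps an internal step-size estimate, so recreating it every cycle would throw that state away. What follows is a minimal sketch of the same pattern, not the project's code: torch.optim.AdamW, the Linear stand-ins, and the learning-rate values are assumptions replacing create_optimizer() and the real models.

import torch

# Stand-ins for the UNet and text encoder LoRA parameters (assumed here);
# torch.optim.AdamW replaces the project's create_optimizer() helper and
# the d-adaptation optimizer it would normally return.
unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)

learning_rate_unet = 1e-4  # assumed value
learning_rate_text = 5e-5  # assumed value

# Build the optimizer once, with one parameter group per module.
params_to_optimize = [
    {"params": [p for p in unet.parameters() if p.requires_grad], "lr": learning_rate_unet},
    {"params": [p for p in text_encoder.parameters() if p.requires_grad], "lr": learning_rate_text},
]
lora_optimizer = torch.optim.AdamW(params_to_optimize)

for cycle in range(3):
    # Each cycle only rewrites the group learning rates; the optimizer
    # object, and therefore its internal state, is reused.
    for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]):
        group["lr"] = lr
    # ... per-cycle training steps would go here ...
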
diff --git a/train_lora.py b/train_lora.py
index 12d7e72..c74dd8f 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -48,8 +48,8 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.WARNING
-# torch._dynamo.config.suppress_errors = True
+# torch._dynamo.config.log_level = logging.WARNING
+torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
@@ -1143,6 +1143,28 @@ def main():
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    params_to_optimize = [
+        {
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_unet,
+        },
+        {
+            "params": (
+                param
+                for param in text_encoder.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_text,
+        }
+    ]
+    group_labels = ["unet", "text"]
+
+    lora_optimizer = create_optimizer(params_to_optimize)
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -1182,35 +1204,9 @@ def main():
         print("")
         print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
         print("")
 
-        params_to_optimize = []
-        group_labels = []
-
-        params_to_optimize.append({
-            "params": (
-                param
-                for param in unet.parameters()
-                if param.requires_grad
-            ),
-            "lr": learning_rate_unet,
-        })
-        group_labels.append("unet")
-
-        if training_iter < args.train_text_encoder_cycles:
-            params_to_optimize.append({
-                "params": (
-                    param
-                    for param in itertools.chain(
-                        text_encoder.text_model.encoder.parameters(),
-                        text_encoder.text_model.final_layer_norm.parameters(),
-                    )
-                    if param.requires_grad
-                ),
-                "lr": learning_rate_text,
-            })
-            group_labels.append("text")
-
-        lora_optimizer = create_optimizer(params_to_optimize)
+        for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]):
+            group['lr'] = lr
 
         lora_lr_scheduler = create_lr_scheduler(
             lr_scheduler,
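
The first hunk flips which torch._dynamo setting is active: the log-level override is commented out and suppress_errors is enabled instead, so a compilation failure (for instance in the hidet backend configured right below it) logs the error and falls back to eager execution rather than aborting training. A minimal sketch of that flag's effect, with a toy function and the default backend assumed for illustration:

import torch
import torch._dynamo

# Same flag the patch enables: backend/compile errors are logged and the
# function silently falls back to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True

@torch.compile  # toy function and shapes are illustrative only
def step(x):
    return torch.nn.functional.gelu(x) * 2.0

out = step(torch.randn(4, 4))  # runs compiled if possible, eager otherwise
print(out.shape)
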