summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48bdcf8..9c1e41c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,6 +1,7 @@
1import argparse 1import argparse
2import datetime 2import datetime
3import logging 3import logging
4import itertools
4from pathlib import Path 5from pathlib import Path
5from functools import partial 6from functools import partial
6 7
@@ -578,14 +579,11 @@ def main():
578 datamodule.setup() 579 datamodule.setup()
579 580
580 optimizer = optimizer_class( 581 optimizer = optimizer_class(
581 [ 582 itertools.chain(
582 { 583 unet.parameters(),
583 'params': unet.parameters(), 584 text_encoder.text_model.encoder.parameters(),
584 }, 585 text_encoder.text_model.final_layer_norm.parameters(),
585 { 586 ),
586 'params': text_encoder.parameters(),
587 }
588 ],
589 lr=args.learning_rate, 587 lr=args.learning_rate,
590 betas=(args.adam_beta1, args.adam_beta2), 588 betas=(args.adam_beta1, args.adam_beta2),
591 weight_decay=args.adam_weight_decay, 589 weight_decay=args.adam_weight_decay,