diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 48bdcf8..9c1e41c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -1,6 +1,7 @@ | |||
1 | import argparse | 1 | import argparse |
2 | import datetime | 2 | import datetime |
3 | import logging | 3 | import logging |
4 | import itertools | ||
4 | from pathlib import Path | 5 | from pathlib import Path |
5 | from functools import partial | 6 | from functools import partial |
6 | 7 | ||
@@ -578,14 +579,11 @@ def main(): | |||
578 | datamodule.setup() | 579 | datamodule.setup() |
579 | 580 | ||
580 | optimizer = optimizer_class( | 581 | optimizer = optimizer_class( |
581 | [ | 582 | itertools.chain( |
582 | { | 583 | unet.parameters(), |
583 | 'params': unet.parameters(), | 584 | text_encoder.text_model.encoder.parameters(), |
584 | }, | 585 | text_encoder.text_model.final_layer_norm.parameters(), |
585 | { | 586 | ), |
586 | 'params': text_encoder.parameters(), | ||
587 | } | ||
588 | ], | ||
589 | lr=args.learning_rate, | 587 | lr=args.learning_rate, |
590 | betas=(args.adam_beta1, args.adam_beta2), | 588 | betas=(args.adam_beta1, args.adam_beta2), |
591 | weight_decay=args.adam_weight_decay, | 589 | weight_decay=args.adam_weight_decay, |