diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 48bdcf8..9c1e41c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -1,6 +1,7 @@ | |||
| 1 | import argparse | 1 | import argparse |
| 2 | import datetime | 2 | import datetime |
| 3 | import logging | 3 | import logging |
| 4 | import itertools | ||
| 4 | from pathlib import Path | 5 | from pathlib import Path |
| 5 | from functools import partial | 6 | from functools import partial |
| 6 | 7 | ||
| @@ -578,14 +579,11 @@ def main(): | |||
| 578 | datamodule.setup() | 579 | datamodule.setup() |
| 579 | 580 | ||
| 580 | optimizer = optimizer_class( | 581 | optimizer = optimizer_class( |
| 581 | [ | 582 | itertools.chain( |
| 582 | { | 583 | unet.parameters(), |
| 583 | 'params': unet.parameters(), | 584 | text_encoder.text_model.encoder.parameters(), |
| 584 | }, | 585 | text_encoder.text_model.final_layer_norm.parameters(), |
| 585 | { | 586 | ), |
| 586 | 'params': text_encoder.parameters(), | ||
| 587 | } | ||
| 588 | ], | ||
| 589 | lr=args.learning_rate, | 587 | lr=args.learning_rate, |
| 590 | betas=(args.adam_beta1, args.adam_beta2), | 588 | betas=(args.adam_beta1, args.adam_beta2), |
| 591 | weight_decay=args.adam_weight_decay, | 589 | weight_decay=args.adam_weight_decay, |
