diff options
author | Volpeon <git@volpeon.ink> | 2022-10-10 17:55:08 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-10 17:55:08 +0200 |
commit | 35116fdf6fb1aedbe0da3cfa9372d53ddb455a26 (patch) | |
tree | 682e93c7b81c343fc64ecb5859e650083df15e4f /dreambooth.py | |
parent | Remove unused code (diff) | |
download | textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.tar.gz textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.tar.bz2 textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.zip |
Added EMA support to Textual Inversion
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 23 |
1 file changed, 12 insertions, 11 deletions
diff --git a/dreambooth.py b/dreambooth.py index f7d31d2..02f83c6 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -129,7 +129,7 @@ def parse_args(): | |||
129 | parser.add_argument( | 129 | parser.add_argument( |
130 | "--learning_rate", | 130 | "--learning_rate", |
131 | type=float, | 131 | type=float, |
132 | default=1e-4, | 132 | default=1e-6, |
133 | help="Initial learning rate (after the potential warmup period) to use.", | 133 | help="Initial learning rate (after the potential warmup period) to use.", |
134 | ) | 134 | ) |
135 | parser.add_argument( | 135 | parser.add_argument( |
@@ -150,7 +150,7 @@ def parse_args(): | |||
150 | parser.add_argument( | 150 | parser.add_argument( |
151 | "--lr_warmup_steps", | 151 | "--lr_warmup_steps", |
152 | type=int, | 152 | type=int, |
153 | default=200, | 153 | default=600, |
154 | help="Number of steps for the warmup in the lr scheduler." | 154 | help="Number of steps for the warmup in the lr scheduler." |
155 | ) | 155 | ) |
156 | parser.add_argument( | 156 | parser.add_argument( |
@@ -162,12 +162,12 @@ def parse_args(): | |||
162 | parser.add_argument( | 162 | parser.add_argument( |
163 | "--ema_inv_gamma", | 163 | "--ema_inv_gamma", |
164 | type=float, | 164 | type=float, |
165 | default=0.1 | 165 | default=1.0 |
166 | ) | 166 | ) |
167 | parser.add_argument( | 167 | parser.add_argument( |
168 | "--ema_power", | 168 | "--ema_power", |
169 | type=float, | 169 | type=float, |
170 | default=1 | 170 | default=1.0 |
171 | ) | 171 | ) |
172 | parser.add_argument( | 172 | parser.add_argument( |
173 | "--ema_max_decay", | 173 | "--ema_max_decay", |
@@ -783,7 +783,12 @@ def main(): | |||
783 | if global_step % args.sample_frequency == 0: | 783 | if global_step % args.sample_frequency == 0: |
784 | sample_checkpoint = True | 784 | sample_checkpoint = True |
785 | 785 | ||
786 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 786 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
787 | if args.use_ema: | ||
788 | logs["ema_decay"] = ema_unet.decay | ||
789 | |||
790 | accelerator.log(logs, step=global_step) | ||
791 | |||
787 | local_progress_bar.set_postfix(**logs) | 792 | local_progress_bar.set_postfix(**logs) |
788 | 793 | ||
789 | if global_step >= args.max_train_steps: | 794 | if global_step >= args.max_train_steps: |
@@ -824,16 +829,12 @@ def main(): | |||
824 | local_progress_bar.update(1) | 829 | local_progress_bar.update(1) |
825 | global_progress_bar.update(1) | 830 | global_progress_bar.update(1) |
826 | 831 | ||
827 | logs = {"mode": "validation", "loss": loss} | 832 | logs = {"val/loss": loss} |
828 | local_progress_bar.set_postfix(**logs) | 833 | local_progress_bar.set_postfix(**logs) |
829 | 834 | ||
830 | val_loss /= len(val_dataloader) | 835 | val_loss /= len(val_dataloader) |
831 | 836 | ||
832 | accelerator.log({ | 837 | accelerator.log({"val/loss": val_loss}, step=global_step) |
833 | "train/loss": train_loss, | ||
834 | "val/loss": val_loss, | ||
835 | "lr": lr_scheduler.get_last_lr()[0] | ||
836 | }, step=global_step) | ||
837 | 838 | ||
838 | local_progress_bar.clear() | 839 | local_progress_bar.clear() |
839 | global_progress_bar.clear() | 840 | global_progress_bar.clear() |