| | |
|---|---|
| author | Volpeon <git@volpeon.ink>, 2022-10-10 17:55:08 +0200 |
| committer | Volpeon <git@volpeon.ink>, 2022-10-10 17:55:08 +0200 |
| commit | 35116fdf6fb1aedbe0da3cfa9372d53ddb455a26 (patch) |
| tree | 682e93c7b81c343fc64ecb5859e650083df15e4f /dreambooth.py |
| parent | Remove unused code (diff) |
| download | textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.tar.gz, textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.tar.bz2, textual-inversion-diff-35116fdf6fb1aedbe0da3cfa9372d53ddb455a26.zip |
Added EMA support to Textual Inversion
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- dreambooth.py | 23

1 file changed, 12 insertions(+), 11 deletions(-)
```diff
diff --git a/dreambooth.py b/dreambooth.py
index f7d31d2..02f83c6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=200,
+        default=600,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -162,12 +162,12 @@ def parse_args():
     parser.add_argument(
         "--ema_inv_gamma",
         type=float,
-        default=0.1
+        default=1.0
     )
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1
+        default=1.0
     )
     parser.add_argument(
         "--ema_max_decay",
```
```diff
@@ -783,7 +783,12 @@ def main():
                 if global_step % args.sample_frequency == 0:
                     sample_checkpoint = True
 
-                logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
+                if args.use_ema:
+                    logs["ema_decay"] = ema_unet.decay
+
+                accelerator.log(logs, step=global_step)
+
                 local_progress_bar.set_postfix(**logs)
 
                 if global_step >= args.max_train_steps:
@@ -824,16 +829,12 @@ def main():
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
-                logs = {"mode": "validation", "loss": loss}
+                logs = {"val/loss": loss}
                 local_progress_bar.set_postfix(**logs)
 
             val_loss /= len(val_dataloader)
 
-            accelerator.log({
-                "train/loss": train_loss,
-                "val/loss": val_loss,
-                "lr": lr_scheduler.get_last_lr()[0]
-            }, step=global_step)
+            accelerator.log({"val/loss": val_loss}, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
```
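For context on the `ema_unet.decay` attribute that is now logged: below is a minimal, self-contained sketch of the EMA bookkeeping such a training loop relies on. The class name, constructor arguments, and `step()` signature are assumptions for illustration, not this repository's actual `EMAModel`:

```python
import copy

import torch


class EMASketch:
    """Illustrative EMA wrapper: keeps a shadow copy of a model's parameters
    and exposes the current decay so the loop can log it as "ema_decay"."""

    def __init__(self, model: torch.nn.Module, inv_gamma: float = 1.0,
                 power: float = 1.0, max_decay: float = 0.9999):
        self.shadow = copy.deepcopy(model).eval().requires_grad_(False)
        self.inv_gamma, self.power, self.max_decay = inv_gamma, power, max_decay
        self.optimization_step = 0
        self.decay = 0.0  # what the training loop reads as ema_unet.decay

    @torch.no_grad()
    def step(self, model: torch.nn.Module) -> None:
        """Blend the live weights into the shadow copy after an optimizer step."""
        self.optimization_step += 1
        self.decay = min(
            1.0 - (1.0 + self.optimization_step / self.inv_gamma) ** -self.power,
            self.max_decay,
        )
        for ema_p, p in zip(self.shadow.parameters(), model.parameters()):
            # shadow <- decay * shadow + (1 - decay) * live weights
            ema_p.lerp_(p, 1.0 - self.decay)
```

After each optimizer step the loop would call something like `ema_unet.step(unet)`; `ema_unet.decay` is then the value that ends up in `logs["ema_decay"]` and is reported through `accelerator.log(logs, step=global_step)`.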
