summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py23
1 files changed, 12 insertions, 11 deletions
diff --git a/dreambooth.py b/dreambooth.py
index f7d31d2..02f83c6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
129 parser.add_argument( 129 parser.add_argument(
130 "--learning_rate", 130 "--learning_rate",
131 type=float, 131 type=float,
132 default=1e-4, 132 default=1e-6,
133 help="Initial learning rate (after the potential warmup period) to use.", 133 help="Initial learning rate (after the potential warmup period) to use.",
134 ) 134 )
135 parser.add_argument( 135 parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
150 parser.add_argument( 150 parser.add_argument(
151 "--lr_warmup_steps", 151 "--lr_warmup_steps",
152 type=int, 152 type=int,
153 default=200, 153 default=600,
154 help="Number of steps for the warmup in the lr scheduler." 154 help="Number of steps for the warmup in the lr scheduler."
155 ) 155 )
156 parser.add_argument( 156 parser.add_argument(
@@ -162,12 +162,12 @@ def parse_args():
162 parser.add_argument( 162 parser.add_argument(
163 "--ema_inv_gamma", 163 "--ema_inv_gamma",
164 type=float, 164 type=float,
165 default=0.1 165 default=1.0
166 ) 166 )
167 parser.add_argument( 167 parser.add_argument(
168 "--ema_power", 168 "--ema_power",
169 type=float, 169 type=float,
170 default=1 170 default=1.0
171 ) 171 )
172 parser.add_argument( 172 parser.add_argument(
173 "--ema_max_decay", 173 "--ema_max_decay",
@@ -783,7 +783,12 @@ def main():
783 if global_step % args.sample_frequency == 0: 783 if global_step % args.sample_frequency == 0:
784 sample_checkpoint = True 784 sample_checkpoint = True
785 785
786 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} 786 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
787 if args.use_ema:
788 logs["ema_decay"] = ema_unet.decay
789
790 accelerator.log(logs, step=global_step)
791
787 local_progress_bar.set_postfix(**logs) 792 local_progress_bar.set_postfix(**logs)
788 793
789 if global_step >= args.max_train_steps: 794 if global_step >= args.max_train_steps:
@@ -824,16 +829,12 @@ def main():
824 local_progress_bar.update(1) 829 local_progress_bar.update(1)
825 global_progress_bar.update(1) 830 global_progress_bar.update(1)
826 831
827 logs = {"mode": "validation", "loss": loss} 832 logs = {"val/loss": loss}
828 local_progress_bar.set_postfix(**logs) 833 local_progress_bar.set_postfix(**logs)
829 834
830 val_loss /= len(val_dataloader) 835 val_loss /= len(val_dataloader)
831 836
832 accelerator.log({ 837 accelerator.log({"val/loss": val_loss}, step=global_step)
833 "train/loss": train_loss,
834 "val/loss": val_loss,
835 "lr": lr_scheduler.get_last_lr()[0]
836 }, step=global_step)
837 838
838 local_progress_bar.clear() 839 local_progress_bar.clear()
839 global_progress_bar.clear() 840 global_progress_bar.clear()