summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py16
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index b6b3594..89ed96a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -111,7 +111,7 @@ def parse_args():
111 parser.add_argument( 111 parser.add_argument(
112 "--learning_rate", 112 "--learning_rate",
113 type=float, 113 type=float,
114 default=1e-4, 114 default=5e-6,
115 help="Initial learning rate (after the potential warmup period) to use.", 115 help="Initial learning rate (after the potential warmup period) to use.",
116 ) 116 )
117 parser.add_argument( 117 parser.add_argument(
@@ -132,7 +132,7 @@ def parse_args():
132 parser.add_argument( 132 parser.add_argument(
133 "--lr_warmup_steps", 133 "--lr_warmup_steps",
134 type=int, 134 type=int,
135 default=500, 135 default=0,
136 help="Number of steps for the warmup in the lr scheduler." 136 help="Number of steps for the warmup in the lr scheduler."
137 ) 137 )
138 parser.add_argument( 138 parser.add_argument(
@@ -352,7 +352,7 @@ class Checkpointer:
352 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 352 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
353 ) 353 )
354 pipeline.enable_attention_slicing() 354 pipeline.enable_attention_slicing()
355 pipeline.save_pretrained(f"{self.output_dir}/model.ckpt") 355 pipeline.save_pretrained(f"{self.output_dir}/model")
356 356
357 del unwrapped 357 del unwrapped
358 del pipeline 358 del pipeline
@@ -540,9 +540,9 @@ def main():
540 # slice_size = unet.config.attention_head_dim // 2 540 # slice_size = unet.config.attention_head_dim // 2
541 # unet.set_attention_slice(slice_size) 541 # unet.set_attention_slice(slice_size)
542 542
543 # Freeze vae and unet 543 # Freeze text_encoder and vae
544 # freeze_params(vae.parameters()) 544 freeze_params(vae.parameters())
545 # freeze_params(text_encoder.parameters()) 545 freeze_params(text_encoder.parameters())
546 546
547 if args.scale_lr: 547 if args.scale_lr:
548 args.learning_rate = ( 548 args.learning_rate = (
@@ -644,8 +644,8 @@ def main():
644 vae.to(accelerator.device) 644 vae.to(accelerator.device)
645 645
646 # Keep text_encoder and vae in eval mode as we don't train these 646 # Keep text_encoder and vae in eval mode as we don't train these
647 # text_encoder.eval() 647 text_encoder.eval()
648 # vae.eval() 648 vae.eval()
649 649
650 # We need to recalculate our total training steps as the size of the training dataloader may have changed. 650 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
651 num_update_steps_per_epoch = math.ceil( 651 num_update_steps_per_epoch = math.ceil(