 dreambooth.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index b6b3594..89ed96a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -111,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -132,7 +132,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=0,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -352,7 +352,7 @@ class Checkpointer:
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
         pipeline.enable_attention_slicing()
-        pipeline.save_pretrained(f"{self.output_dir}/model.ckpt")
+        pipeline.save_pretrained(f"{self.output_dir}/model")

         del unwrapped
         del pipeline
@@ -540,9 +540,9 @@ def main():
     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)

-    # Freeze vae and unet
-    # freeze_params(vae.parameters())
-    # freeze_params(text_encoder.parameters())
+    # Freeze text_encoder and vae
+    freeze_params(vae.parameters())
+    freeze_params(text_encoder.parameters())

     if args.scale_lr:
         args.learning_rate = (
@@ -644,8 +644,8 @@ def main():
     vae.to(accelerator.device)

     # Keep text_encoder and vae in eval mode as we don't train these
-    # text_encoder.eval()
-    # vae.eval()
+    text_encoder.eval()
+    vae.eval()

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(
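The first two hunks lower the default --learning_rate from 1e-4 to 5e-6 and set --lr_warmup_steps to 0. DreamBooth fine-tunes the full UNet, which generally calls for a much smaller learning rate than the 1e-4 commonly used for textual inversion. A minimal sketch of how these defaults would plug into an AdamW optimizer and a diffusers LR scheduler, assuming the script follows the usual diffusers training pattern (the "constant" schedule and the step count are assumptions, not taken from this diff):

import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(4, 4)  # stand-in for the UNet the script actually trains
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)  # new --learning_rate default
lr_scheduler = get_scheduler(
    "constant",                # assumed schedule; not specified in this diff
    optimizer=optimizer,
    num_warmup_steps=0,        # new --lr_warmup_steps default
    num_training_steps=1000,   # placeholder
)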
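The Checkpointer hunk renames the save target from model.ckpt to model: save_pretrained() writes a directory tree of pipeline components (unet, vae, text_encoder, scheduler, ...), not a single checkpoint file, so the .ckpt suffix was misleading. A reload sketch, with an illustrative output path and prompt:

from diffusers import StableDiffusionPipeline

# Load the pipeline back from the saved directory and run inference.
pipeline = StableDiffusionPipeline.from_pretrained("output/model")
image = pipeline("a photo of sks dog").images[0]
image.save("sample.png")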
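The last two hunks uncomment the code that freezes the vae and text_encoder and switches them to eval mode, so only the UNet receives gradient updates. freeze_params is defined elsewhere in the script; a common implementation (an assumption here, since the diff does not show it) simply disables gradients:

def freeze_params(params):
    # Detach these parameters from the autograd graph so the
    # optimizer never updates them.
    for param in params:
        param.requires_grad = False

Calling .eval() on top of freezing also disables train-time behaviors such as dropout, which requires_grad alone does not affect.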