From 1083830fa2f751476df2bed370f0468d39b37874 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Sep 2022 14:27:08 +0200
Subject: Freeze models that aren't trained

---
 dreambooth.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index b6b3594..89ed96a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -111,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -132,7 +132,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=500,
+        default=0,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -352,7 +352,7 @@ class Checkpointer:
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
         pipeline.enable_attention_slicing()
-        pipeline.save_pretrained(f"{self.output_dir}/model.ckpt")
+        pipeline.save_pretrained(f"{self.output_dir}/model")
 
         del unwrapped
         del pipeline
@@ -540,9 +540,9 @@ def main():
     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)
 
-    # Freeze vae and unet
-    # freeze_params(vae.parameters())
-    # freeze_params(text_encoder.parameters())
+    # Freeze text_encoder and vae
+    freeze_params(vae.parameters())
+    freeze_params(text_encoder.parameters())
 
     if args.scale_lr:
         args.learning_rate = (
@@ -644,8 +644,8 @@ def main():
     vae.to(accelerator.device)
 
     # Keep text_encoder and vae in eval mode as we don't train these
-    # text_encoder.eval()
-    # vae.eval()
+    text_encoder.eval()
+    vae.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(
-- 
cgit v1.2.3-70-g09d2
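
Note: freeze_params is a helper defined elsewhere in dreambooth.py, so its body is not visible in this patch. Below is a minimal sketch of the freeze/eval pattern the patch enables; the helper's implementation and the pretrained checkpoint name are assumptions for illustration, not taken from the diff. Only the freeze_params(...) and .eval() calls mirror the patch itself.

    # Sketch of the freeze/eval pattern applied by this patch. The body of
    # freeze_params and the checkpoint name are assumptions; the load calls
    # use the standard diffusers/transformers APIs of the era.
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel

    def freeze_params(params):
        # Disable gradient tracking so the optimizer never updates these weights.
        for param in params:
            param.requires_grad = False

    pretrained = "CompVis/stable-diffusion-v1-4"  # assumed checkpoint
    vae = AutoencoderKL.from_pretrained(pretrained, subfolder="vae")
    text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

    # Freeze text_encoder and vae; in DreamBooth only the UNet is trained.
    freeze_params(vae.parameters())
    freeze_params(text_encoder.parameters())

    # Keep the frozen models in eval mode so modules such as dropout behave
    # deterministically; they are never switched back to train().
    text_encoder.eval()
    vae.eval()

Freezing before the optimizer is constructed matters: only parameters with requires_grad left True end up being updated, which also lets a much smaller learning rate (5e-6 instead of 1e-4, per the first hunk) suffice for fine-tuning the UNet alone.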