From 06bfe1fccdc0976bacf9bfe2ae17d440fa416aab Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 22 Jun 2023 19:16:46 +0200 Subject: Update --- train_dreambooth.py | 2 +- training/strategy/dreambooth.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index 7745d27..d284346 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -262,7 +262,7 @@ def parse_args(): ) parser.add_argument( "--text_encoder_unfreeze_last_n_layers", - default=2, + default=-1, help="Number of text encoder layers to train.", ) parser.add_argument( diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 35cccbb..dc19ba3 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -31,6 +31,7 @@ def dreambooth_strategy_callbacks( checkpoint_output_dir: Path, seed: int, train_text_encoder_cycles: int, + text_encoder_unfreeze_last_n_layers: int = 2, max_grad_norm: float = 1.0, use_ema: bool = False, ema_inv_gamma: float = 1.0, @@ -211,7 +212,7 @@ def dreambooth_prepare( ]: layer.requires_grad_(False) - text_encoder.text_model.embeddings.requires_grad_(False) + # text_encoder.text_model.embeddings.requires_grad_(False) return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler -- cgit v1.2.3-70-g09d2