summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-06-22 19:16:46 +0200
committerVolpeon <git@volpeon.ink>2023-06-22 19:16:46 +0200
commit06bfe1fccdc0976bacf9bfe2ae17d440fa416aab (patch)
tree5be3b72811de6b66d8c5a2e641b589363a47ec32
parentAdded prompt dropout (diff)
downloadtextual-inversion-diff-06bfe1fccdc0976bacf9bfe2ae17d440fa416aab.tar.gz
textual-inversion-diff-06bfe1fccdc0976bacf9bfe2ae17d440fa416aab.tar.bz2
textual-inversion-diff-06bfe1fccdc0976bacf9bfe2ae17d440fa416aab.zip
Update
-rw-r--r--train_dreambooth.py2
-rw-r--r--training/strategy/dreambooth.py3
2 files changed, 3 insertions, 2 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 7745d27..d284346 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -262,7 +262,7 @@ def parse_args():
262 ) 262 )
263 parser.add_argument( 263 parser.add_argument(
264 "--text_encoder_unfreeze_last_n_layers", 264 "--text_encoder_unfreeze_last_n_layers",
265 default=2, 265 default=-1,
266 help="Number of text encoder layers to train.", 266 help="Number of text encoder layers to train.",
267 ) 267 )
268 parser.add_argument( 268 parser.add_argument(
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 35cccbb..dc19ba3 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -31,6 +31,7 @@ def dreambooth_strategy_callbacks(
31 checkpoint_output_dir: Path, 31 checkpoint_output_dir: Path,
32 seed: int, 32 seed: int,
33 train_text_encoder_cycles: int, 33 train_text_encoder_cycles: int,
34 text_encoder_unfreeze_last_n_layers: int = 2,
34 max_grad_norm: float = 1.0, 35 max_grad_norm: float = 1.0,
35 use_ema: bool = False, 36 use_ema: bool = False,
36 ema_inv_gamma: float = 1.0, 37 ema_inv_gamma: float = 1.0,
@@ -211,7 +212,7 @@ def dreambooth_prepare(
211 ]: 212 ]:
212 layer.requires_grad_(False) 213 layer.requires_grad_(False)
213 214
214 text_encoder.text_model.embeddings.requires_grad_(False) 215 # text_encoder.text_model.embeddings.requires_grad_(False)
215 216
216 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 217 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
217 218