From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- training/strategy/dreambooth.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) (limited to 'training/strategy/dreambooth.py') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e6fcc89..88b441b 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks( sample_output_dir: Path, checkpoint_output_dir: Path, seed: int, - train_text_encoder_epochs: int, + train_text_encoder_cycles: int, max_grad_norm: float = 1.0, use_ema: bool = False, ema_inv_gamma: float = 1.0, @@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks( return nullcontext() @contextmanager - def on_train(epoch: int): + def on_train(cycle: int): unet.train() tokenizer.train() - if epoch < train_text_encoder_epochs: + if cycle < train_text_encoder_cycles: text_encoder.train() - elif epoch == train_text_encoder_epochs: - text_encoder.requires_grad_(False) - text_encoder.eval() + tokenizer.train() yield @@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks( with ema_context(): yield - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): params_to_clip = [unet.parameters()] - if epoch < train_text_encoder_epochs: + if cycle < train_text_encoder_cycles: params_to_clip.append(text_encoder.parameters()) accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) @@ -189,8 +187,16 @@ def dreambooth_prepare( lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) text_encoder.text_model.embeddings.requires_grad_(False) @@ -198,6 +204,5 @@ def dreambooth_prepare( dreambooth_strategy = TrainingStrategy( - callbacks=dreambooth_strategy_callbacks, - prepare=dreambooth_prepare + callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare ) -- cgit v1.2.3-54-g00ecf