diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
| commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
| tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/strategy/dreambooth.py | |
| parent | Fix LoRA training with DAdan (diff) | |
| download | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2 textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip | |
Update
Diffstat (limited to 'training/strategy/dreambooth.py')
| -rw-r--r-- | training/strategy/dreambooth.py | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e6fcc89..88b441b 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks( | |||
| 29 | sample_output_dir: Path, | 29 | sample_output_dir: Path, |
| 30 | checkpoint_output_dir: Path, | 30 | checkpoint_output_dir: Path, |
| 31 | seed: int, | 31 | seed: int, |
| 32 | train_text_encoder_epochs: int, | 32 | train_text_encoder_cycles: int, |
| 33 | max_grad_norm: float = 1.0, | 33 | max_grad_norm: float = 1.0, |
| 34 | use_ema: bool = False, | 34 | use_ema: bool = False, |
| 35 | ema_inv_gamma: float = 1.0, | 35 | ema_inv_gamma: float = 1.0, |
| @@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks( | |||
| 85 | return nullcontext() | 85 | return nullcontext() |
| 86 | 86 | ||
| 87 | @contextmanager | 87 | @contextmanager |
| 88 | def on_train(epoch: int): | 88 | def on_train(cycle: int): |
| 89 | unet.train() | 89 | unet.train() |
| 90 | tokenizer.train() | 90 | tokenizer.train() |
| 91 | 91 | ||
| 92 | if epoch < train_text_encoder_epochs: | 92 | if cycle < train_text_encoder_cycles: |
| 93 | text_encoder.train() | 93 | text_encoder.train() |
| 94 | elif epoch == train_text_encoder_epochs: | 94 | tokenizer.train() |
| 95 | text_encoder.requires_grad_(False) | ||
| 96 | text_encoder.eval() | ||
| 97 | 95 | ||
| 98 | yield | 96 | yield |
| 99 | 97 | ||
| @@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks( | |||
| 106 | with ema_context(): | 104 | with ema_context(): |
| 107 | yield | 105 | yield |
| 108 | 106 | ||
| 109 | def on_before_optimize(epoch: int): | 107 | def on_before_optimize(cycle: int): |
| 110 | params_to_clip = [unet.parameters()] | 108 | params_to_clip = [unet.parameters()] |
| 111 | if epoch < train_text_encoder_epochs: | 109 | if cycle < train_text_encoder_cycles: |
| 112 | params_to_clip.append(text_encoder.parameters()) | 110 | params_to_clip.append(text_encoder.parameters()) |
| 113 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) | 111 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) |
| 114 | 112 | ||
| @@ -189,8 +187,16 @@ def dreambooth_prepare( | |||
| 189 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 187 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 190 | **kwargs | 188 | **kwargs |
| 191 | ): | 189 | ): |
| 192 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 190 | ( |
| 193 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 191 | text_encoder, |
| 192 | unet, | ||
| 193 | optimizer, | ||
| 194 | train_dataloader, | ||
| 195 | val_dataloader, | ||
| 196 | lr_scheduler, | ||
| 197 | ) = accelerator.prepare( | ||
| 198 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 199 | ) | ||
| 194 | 200 | ||
| 195 | text_encoder.text_model.embeddings.requires_grad_(False) | 201 | text_encoder.text_model.embeddings.requires_grad_(False) |
| 196 | 202 | ||
| @@ -198,6 +204,5 @@ def dreambooth_prepare( | |||
| 198 | 204 | ||
| 199 | 205 | ||
| 200 | dreambooth_strategy = TrainingStrategy( | 206 | dreambooth_strategy = TrainingStrategy( |
| 201 | callbacks=dreambooth_strategy_callbacks, | 207 | callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare |
| 202 | prepare=dreambooth_prepare | ||
| 203 | ) | 208 | ) |
