| author | Volpeon <git@volpeon.ink> | 2023-01-19 09:04:39 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-19 09:04:39 +0100 |
| commit | 2469501c3951a9ed86c820cddf7b32144a4a1c8d | |
| tree | 9820efaa12fd31670616c1fd9da3e6bb06580aaf | |
| parent | Update | |
Move Accelerator preparation into strategy
Diffstat (limited to 'training/strategy/ti.py')
| -rw-r--r-- | training/strategy/ti.py | 22 |
1 file changed, 21 insertions(+), 1 deletion(-)
```diff
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index e922954..6a76f98 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -5,6 +5,7 @@ from contextlib import contextmanager, nullcontext
 from pathlib import Path
 
 import torch
+import torch.nn as nn
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
@@ -94,7 +95,7 @@ def textual_inversion_strategy_callbacks(
         return nullcontext()
 
     def on_model():
-        return text_encoder
+        return text_encoder.text_model.embeddings.temp_token_embedding
 
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
@@ -163,6 +164,25 @@ def textual_inversion_strategy_callbacks(
     )
 
 
+def textual_inversion_prepare(
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    unet: UNet2DConditionModel,
+    *args
+):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    prep = [text_encoder] + list(args)
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(*prep)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+
+
 textual_inversion_strategy = TrainingStrategy(
     callbacks=textual_inversion_strategy_callbacks,
+    prepare=textual_inversion_prepare,
 )
```
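For context, a minimal sketch of the strategy container this diff extends. Only the field names `callbacks` and `prepare` are visible in the diff; the dataclass shape, the type hints, and the commented-out trainer call are assumptions for illustration, not the repository's actual definitions:

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple

@dataclass
class TrainingStrategy:
    # Factory building the per-strategy training callbacks
    # (on_model, on_prepare, checkpointing hooks, ...).
    callbacks: Callable[..., Any]
    # Hook that wraps models, optimizer, dataloaders and scheduler for
    # Accelerate; each strategy decides what accelerator.prepare() sees.
    prepare: Callable[..., Tuple[Any, ...]]

# Hypothetical call site in the generic trainer (not shown in this diff):
#
#   text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = \
#       strategy.prepare(accelerator, text_encoder, unet,
#                        optimizer, train_dataloader, val_dataloader, lr_scheduler)
```

Note the design choice visible in `textual_inversion_prepare`: only the text encoder goes through `accelerator.prepare`, while the UNet is merely moved to the device at the mixed-precision dtype. Since textual inversion trains only the token embeddings inside the text encoder, the frozen UNet needs no gradient tracking or distributed wrapping.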
