author     Volpeon <git@volpeon.ink>  2023-03-24 10:53:16 +0100
committer  Volpeon <git@volpeon.ink>  2023-03-24 10:53:16 +0100
commit     95adaea8b55d8e3755c035758bc649ae22548572 (patch)
tree       80239f0bc55b99615718a935be2caa2e1e68e20a /training/strategy/dreambooth.py
parent     Bring back Perlin offset noise (diff)
Refactoring, fixed Lora training
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--  training/strategy/dreambooth.py | 17
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 28fccff..9808027 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -74,6 +74,7 @@ def dreambooth_strategy_callbacks(
             power=ema_power,
             max_value=ema_max_decay,
         )
+        ema_unet.to(accelerator.device)
     else:
         ema_unet = None
 
@@ -86,14 +87,6 @@ def dreambooth_strategy_callbacks(
     def on_accum_model():
         return unet
 
-    def on_prepare():
-        unet.requires_grad_(True)
-        text_encoder.text_model.encoder.requires_grad_(True)
-        text_encoder.text_model.final_layer_norm.requires_grad_(True)
-
-        if ema_unet is not None:
-            ema_unet.to(accelerator.device)
-
     @contextmanager
     def on_train(epoch: int):
         tokenizer.train()
@@ -181,7 +174,6 @@ def dreambooth_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -203,7 +195,12 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.requires_grad_(False)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 dreambooth_strategy = TrainingStrategy(
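
Context on the last hunk: after this commit, dreambooth_prepare both wraps the training objects with accelerator.prepare and freezes the text encoder's input embeddings, returning the six prepared objects plus an (empty) dict of extra state; the per-strategy on_prepare callback is removed, and the EMA model is moved to the accelerator device as soon as it is created. The sketch below only illustrates the new embedding freeze on a toy CLIP text encoder; the tiny config is an assumption for illustration, not the model used in this repository.

    from transformers import CLIPTextConfig, CLIPTextModel

    # Tiny randomly initialised text encoder, just to demonstrate the freeze;
    # the real training code uses the Stable Diffusion text encoder instead.
    config = CLIPTextConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
    )
    text_encoder = CLIPTextModel(config)

    # The commit freezes only the input (token/position) embeddings; the
    # encoder layers and final layer norm keep requires_grad=True.
    text_encoder.text_model.embeddings.requires_grad_(False)

    frozen = sum(not p.requires_grad for p in text_encoder.parameters())
    trainable = sum(p.requires_grad for p in text_encoder.parameters())
    print(f"frozen tensors: {frozen}, trainable tensors: {trainable}")

Freezing the embeddings at prepare time means a training pass over the text encoder no longer updates the raw token and position embeddings, which is presumably what "fixed Lora training" in the commit message refers to.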