path: root/training/strategy/dreambooth.py
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--  training/strategy/dreambooth.py  17
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 28fccff..9808027 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -74,6 +74,7 @@ def dreambooth_strategy_callbacks(
             power=ema_power,
             max_value=ema_max_decay,
         )
+        ema_unet.to(accelerator.device)
     else:
         ema_unet = None

@@ -86,14 +87,6 @@ def dreambooth_strategy_callbacks(
     def on_accum_model():
         return unet

-    def on_prepare():
-        unet.requires_grad_(True)
-        text_encoder.text_model.encoder.requires_grad_(True)
-        text_encoder.text_model.final_layer_norm.requires_grad_(True)
-
-        if ema_unet is not None:
-            ema_unet.to(accelerator.device)
-
     @contextmanager
     def on_train(epoch: int):
         tokenizer.train()
@@ -181,7 +174,6 @@ def dreambooth_strategy_callbacks(
         torch.cuda.empty_cache()

     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -203,7 +195,12 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+
+    text_encoder.text_model.embeddings.requires_grad_(False)
+
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


 dreambooth_strategy = TrainingStrategy(
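
Note on the first hunk: the EMA copy is now moved to the training device right after it is constructed (ema_unet.to(accelerator.device)), so the on_prepare callback that used to do this is removed. A minimal sketch of that pattern, assuming plain torch and a hypothetical ToyEMA wrapper; this is not the repo's EMAModel API:

import copy
import torch

class ToyEMA:
    # Hypothetical illustration of an EMA shadow-weight holder.
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        self.shadow.requires_grad_(False)

    def to(self, device):
        self.shadow.to(device)
        return self

    @torch.no_grad()
    def step(self, model: torch.nn.Module):
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.lerp_(p, 1.0 - self.decay)  # s = decay*s + (1-decay)*p

device = "cuda" if torch.cuda.is_available() else "cpu"
unet = torch.nn.Linear(4, 4).to(device)
ema_unet = ToyEMA(unet).to(device)  # device placement at construction time,
                                    # not in a later on_prepare callback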
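Note on the last hunk: gradient setup is folded into dreambooth_prepare itself. accelerator.prepare() runs first, then the text encoder's embeddings are frozen on the prepared model. A minimal sketch under the same caveat, with a hypothetical ToyTextModel standing in for the CLIP text encoder:

import torch
from accelerate import Accelerator

class ToyTextModel(torch.nn.Module):
    # Hypothetical stand-in for the text encoder in this repo.
    def __init__(self):
        super().__init__()
        self.embeddings = torch.nn.Embedding(100, 16)
        self.encoder = torch.nn.Linear(16, 16)

accelerator = Accelerator()
text_encoder = ToyTextModel()
optimizer = torch.optim.AdamW(text_encoder.parameters(), lr=1e-4)

text_encoder, optimizer = accelerator.prepare(text_encoder, optimizer)

# Freeze only the embedding submodule; the encoder keeps training.
# unwrap_model() guards against the DDP wrapper prepare() may add.
accelerator.unwrap_model(text_encoder).embeddings.requires_grad_(False)

Doing this after prepare() keeps all parameter-freezing logic in one place, which is why the separate on_prepare callback (and its entry in TrainingCallbacks) could be dropped.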