author     Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
commit     8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree       152c99815bbd8b2659d0dabe63c98f63151c97c2 /training/strategy/dreambooth.py
parent     Fix LoRA training with DAdan (diff)
Update
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--  training/strategy/dreambooth.py | 29 +++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e6fcc89..88b441b 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    train_text_encoder_epochs: int,
+    train_text_encoder_cycles: int,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks(
         return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
         tokenizer.train()
 
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             text_encoder.train()
-        elif epoch == train_text_encoder_epochs:
-            text_encoder.requires_grad_(False)
-            text_encoder.eval()
+            tokenizer.train()
 
         yield
 
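Note on this hunk: the gate on text-encoder training is now keyed to a cycle index rather than the epoch, and the boundary branch that froze the encoder (requires_grad_(False) / eval()) is removed. A toy, self-contained sketch of the gating as it now reads; the modules and the loop here are stand-ins for illustration, not this repository's objects:

from contextlib import contextmanager
import torch

unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)
train_text_encoder_cycles = 2

@contextmanager
def on_train(cycle: int):
    unet.train()
    if cycle < train_text_encoder_cycles:
        text_encoder.train()
    yield

for cycle in range(4):
    text_encoder.eval()  # stand-in for an eval/sampling phase between cycles
    with on_train(cycle):
        # prints True, True, False, False: later cycles no longer touch the
        # encoder's mode (the removed code forced eval() at the boundary)
        print(cycle, text_encoder.training)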
@@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         params_to_clip = [unet.parameters()]
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
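on_before_optimize now uses the same cycle gate. The parameter iterators are chained so that clip_grad_norm_ computes a single global norm over everything passed, not one norm per model; Accelerate's clip_grad_norm_ wraps torch's and additionally unscales gradients under mixed precision. The same pattern in isolation, on placeholder modules and with torch's own clipping function:

import itertools
import torch

unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)
max_grad_norm = 1.0
train_text_encoder = True  # stands in for: cycle < train_text_encoder_cycles

# Populate some gradients so there is something to clip.
(unet(torch.randn(2, 8)).sum() + text_encoder(torch.randn(2, 8)).sum()).backward()

params_to_clip = [unet.parameters()]
if train_text_encoder:
    params_to_clip.append(text_encoder.parameters())
# One global norm over the chained iterators, not one per model.
total_norm = torch.nn.utils.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
print(total_norm)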
@@ -189,8 +187,16 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     text_encoder.text_model.embeddings.requires_grad_(False)
 
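This hunk is formatting only, likely a pass with a formatter such as black: the accelerator.prepare call and its tuple unpacking are reflowed with no behavior change. prepare returns its arguments wrapped for the active device and precision setup, in the order they were passed; a toy demonstration on stand-in objects:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Same order out as in; each object comes back wrapped for the device setup.
model, optimizer = accelerator.prepare(model, optimizer)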
@@ -198,6 +204,5 @@ def dreambooth_prepare(
 
 
 dreambooth_strategy = TrainingStrategy(
-    callbacks=dreambooth_strategy_callbacks,
-    prepare=dreambooth_prepare
+    callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare
 )
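TrainingStrategy itself is defined elsewhere in this repository; this call site only shows that it pairs a callback factory with a prepare function. A speculative minimal shape, not the actual definition:

from dataclasses import dataclass
from typing import Callable

@dataclass
class TrainingStrategy:
    callbacks: Callable  # e.g. dreambooth_strategy_callbacks
    prepare: Callable    # e.g. dreambooth_prepare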