 train_dreambooth.py             | 14
 training/functional.py          | 12
 training/lr.py                  |  2
 training/strategy/dreambooth.py |  5
 training/strategy/ti.py         | 14
 5 files changed, 25 insertions(+), 22 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48bdcf8..9c1e41c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from pathlib import Path
 from functools import partial
 
@@ -578,14 +579,11 @@ def main():
     datamodule.setup()
 
     optimizer = optimizer_class(
-        [
-            {
-                'params': unet.parameters(),
-            },
-            {
-                'params': text_encoder.parameters(),
-            }
-        ],
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.text_model.encoder.parameters(),
+            text_encoder.text_model.final_layer_norm.parameters(),
+        ),
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
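
Note on the hunk above: feeding the optimizer a single itertools.chain(...) iterable collapses everything into one param group, so one learning rate and one weight decay now cover the UNet together with the selected text-encoder submodules, where the old list-of-dicts form produced two groups. A minimal sketch of that behaviour, assuming stand-in modules (not part of the patch):

    import itertools
    import torch

    # Illustrative stand-ins for the real unet / text_encoder submodules.
    unet_stub = torch.nn.Linear(4, 4)
    text_encoder_stub = torch.nn.Linear(4, 4)

    optimizer = torch.optim.AdamW(
        itertools.chain(unet_stub.parameters(), text_encoder_stub.parameters()),
        lr=1e-4,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
    )

    # One group now, where the previous list-of-dicts form produced two.
    assert len(optimizer.param_groups) == 1
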
diff --git a/training/functional.py b/training/functional.py
index 7a3e821..a450ef6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional, Type
+from typing import Callable, Any, Tuple, Union, Optional, Protocol
 from functools import partial
 from pathlib import Path
 import itertools
@@ -37,7 +37,7 @@ class TrainingCallbacks():
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[int], None] = const()
+    on_before_optimize: Callable[[float, int], None] = const()
     on_after_optimize: Callable[[float], None] = const()
     on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
@@ -331,13 +331,17 @@ def loss_step(
     return loss, acc, bsz
 
 
+class LossCallable(Protocol):
+    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...
+
+
 def train_loop(
     accelerator: Accelerator,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
+    loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
@@ -406,7 +410,7 @@ def train_loop(
 
                         accelerator.backward(loss)
 
-                        on_before_optimize(epoch)
+                        on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
                         optimizer.step()
                         lr_scheduler.step()
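
For context on the typing change above, a minimal sketch (not part of the patch) of how the new LossCallable Protocol is satisfied: any callable whose signature structurally matches __call__ is accepted as loss_step, no inheritance required. The function below is a placeholder, not the real loss_step:

    from typing import Any, Protocol, Tuple


    class LossCallable(Protocol):
        def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...


    def dummy_loss_step(step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]:
        # Placeholder body: a real implementation runs the model and returns (loss, acc, bsz).
        return 0.0, 0.0, len(batch)


    loss_fn: LossCallable = dummy_loss_step  # structurally compatible, so a type checker accepts it
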
diff --git a/training/lr.py b/training/lr.py
index 902c4eb..9690738 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -101,7 +101,7 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
-                on_before_optimize(epoch)
+                on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
                 self.optimizer.step()
                 lr_scheduler.step()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index d813b49..f57e736 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -99,8 +99,7 @@ def dreambooth_strategy_callbacks(
     def on_prepare():
         unet.requires_grad_(True)
         text_encoder.requires_grad_(True)
-        text_encoder.text_model.embeddings.persist()
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
+        text_encoder.text_model.embeddings.requires_grad_(False)
 
         if ema_unet is not None:
             ema_unet.to(accelerator.device)
@@ -125,7 +124,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(lr: float, epoch: int):
         if accelerator.sync_gradients:
             params_to_clip = [unet.parameters()]
             if epoch < train_text_encoder_epochs:
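
The dreambooth callback head above leads into gradient clipping that always covers the UNet and includes the text encoder only while epoch < train_text_encoder_epochs. A hedged sketch of that pattern (the clip_before_step name, the torch.nn.utils.clip_grad_norm_ call, and max_grad_norm are illustrative, not taken from this patch):

    import itertools

    import torch


    def clip_before_step(unet, text_encoder, epoch, train_text_encoder_epochs, max_grad_norm=1.0):
        # Always clip the UNet; clip the text encoder only during its training epochs.
        params_to_clip = [unet.parameters()]
        if epoch < train_text_encoder_epochs:
            params_to_clip.append(text_encoder.parameters())
        torch.nn.utils.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
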
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index ba78b98..e922954 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -117,14 +117,15 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_after_optimize(lr: float):
+    @torch.no_grad()
+    def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            with torch.no_grad():
-                text_encoder.text_model.embeddings.normalize(
-                    emb_decay_target,
-                    min(1.0, emb_decay * lr)
-                )
+            text_encoder.text_model.embeddings.normalize(
+                emb_decay_target,
+                min(1.0, emb_decay * lr)
+            )
 
+    def on_after_optimize(lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -154,6 +155,7 @@ def textual_inversion_strategy_callbacks(
         on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
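
With this change the textual-inversion embedding decay runs in on_before_optimize, under @torch.no_grad(), and its strength is capped at min(1.0, emb_decay * lr) so a large learning rate cannot overshoot the decay target. The embeddings.normalize method itself is not part of this diff; the sketch below is one plausible reading under that assumption and may differ from the real implementation:

    import torch


    @torch.no_grad()
    def normalize_embeddings(weight: torch.Tensor, target: float, lambda_: float) -> None:
        # Pull each embedding vector's norm a fraction lambda_ of the way toward `target`.
        norms = weight.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        weight.mul_(1.0 - lambda_ + lambda_ * target / norms)

    # Hypothetical usage mirroring the callback:
    # normalize_embeddings(emb_weight, emb_decay_target, min(1.0, emb_decay * lr))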