 train_dreambooth.py             | 14 ++++++--------
 training/functional.py          | 12 ++++++++----
 training/lr.py                  |  2 +-
 training/strategy/dreambooth.py |  5 ++---
 training/strategy/ti.py         | 14 ++++++++------
 5 files changed, 25 insertions(+), 22 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48bdcf8..9c1e41c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -1,6 +1,7 @@
 import argparse
 import datetime
 import logging
+import itertools
 from pathlib import Path
 from functools import partial
 
@@ -578,14 +579,11 @@ def main():
     datamodule.setup()
 
     optimizer = optimizer_class(
-        [
-            {
-                'params': unet.parameters(),
-            },
-            {
-                'params': text_encoder.parameters(),
-            }
-        ],
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.text_model.encoder.parameters(),
+            text_encoder.text_model.final_layer_norm.parameters(),
+        ),
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
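The optimizer call above now feeds a single itertools.chain of parameter iterators instead of two param-group dicts, and it trains only the text model's transformer blocks and final layer norm (the token embeddings stay frozen, matching the strategy change further down). A minimal sketch of the pattern, using stand-in torch.nn.Linear modules rather than the real UNet and text encoder:

# Sketch: one chained iterable of parameters instead of per-module param
# groups. The Linear modules are illustrative stand-ins, not the real models.
import itertools
import torch

unet_like = torch.nn.Linear(4, 4)     # stand-in for the UNet
encoder_like = torch.nn.Linear(4, 4)  # stand-in for the text encoder blocks

optimizer = torch.optim.AdamW(
    itertools.chain(unet_like.parameters(), encoder_like.parameters()),
    lr=1e-4,
    weight_decay=1e-2,
)

# All parameters now live in a single param group:
assert len(optimizer.param_groups) == 1

One consequence worth noting: with param-group dicts each module could carry its own learning rate or weight decay; chaining collapses everything into one group with shared hyperparameters.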
diff --git a/training/functional.py b/training/functional.py
index 7a3e821..a450ef6 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional, Type
+from typing import Callable, Any, Tuple, Union, Optional, Protocol
 from functools import partial
 from pathlib import Path
 import itertools
@@ -37,7 +37,7 @@ class TrainingCallbacks():
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
-    on_before_optimize: Callable[[int], None] = const()
+    on_before_optimize: Callable[[float, int], None] = const()
     on_after_optimize: Callable[[float], None] = const()
     on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
@@ -331,13 +331,17 @@ def loss_step(
     return loss, acc, bsz
 
 
+class LossCallable(Protocol):
+    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...
+
+
 def train_loop(
     accelerator: Accelerator,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
+    loss_step: LossCallable,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
@@ -406,7 +410,7 @@ def train_loop(
 
                 accelerator.backward(loss)
 
-                on_before_optimize(epoch)
+                on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
                 optimizer.step()
                 lr_scheduler.step()
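The Union of two Callable shapes is dropped because typing.Callable cannot express an optional keyword argument like eval: bool = False, while a Protocol with a __call__ method can. A short, self-contained illustration of that point (my_loss_step and its dummy return values are purely for demonstration):

# Sketch: a callable Protocol describes a signature with a keyword default,
# which Callable[...] cannot. Only the Protocol shape comes from the diff.
from typing import Any, Protocol, Tuple


class LossCallable(Protocol):
    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...


def my_loss_step(step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]:
    # Dummy (loss, accuracy, batch size) triple for illustration.
    return 0.0, 0.0, len(batch)


fn: LossCallable = my_loss_step  # type-checks: signature matches the Protocol
loss, acc, bsz = fn(0, {"pixel_values": None}, eval=True)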
diff --git a/training/lr.py b/training/lr.py
index 902c4eb..9690738 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -101,7 +101,7 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
-                on_before_optimize(epoch)
+                on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)
 
                 self.optimizer.step()
                 lr_scheduler.step()
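Both call sites now pass the current learning rate as the first argument. get_last_lr() returns a list with one entry per param group, so [0] selects the single group's most recent value. A quick standalone check (the stand-in model and scheduler choice are assumptions, not the project's):

# Sketch: what lr_scheduler.get_last_lr() returns — a list, one entry per
# param group, so [0] picks the single group's current learning rate.
import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

print(scheduler.get_last_lr())     # [0.1]
optimizer.step()
scheduler.step()
print(scheduler.get_last_lr()[0])  # ~0.09 — the value handed to on_before_optimize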
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index d813b49..f57e736 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -99,8 +99,7 @@ def dreambooth_strategy_callbacks(
     def on_prepare():
         unet.requires_grad_(True)
         text_encoder.requires_grad_(True)
-        text_encoder.text_model.embeddings.persist()
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
+        text_encoder.text_model.embeddings.requires_grad_(False)
 
         if ema_unet is not None:
             ema_unet.to(accelerator.device)
@@ -125,7 +124,7 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(lr: float, epoch: int):
         if accelerator.sync_gradients:
             params_to_clip = [unet.parameters()]
             if epoch < train_text_encoder_epochs:
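The new lr parameter is unused by this strategy's clipping logic, which keeps its shape: collect parameter iterators conditionally, chain them, and clip once per synchronized step. A hedged sketch of that pattern with stand-in modules and assumed hyperparameter values (accelerator.clip_grad_norm_ and accelerator.sync_gradients are the real Accelerate API; everything else here is illustrative):

# Sketch of the clipping pattern around on_before_optimize. The Linear
# modules and hyperparameter values are stand-ins, not the project's.
import itertools
import torch
from accelerate import Accelerator

accelerator = Accelerator()
unet = torch.nn.Linear(4, 4)           # stand-in for the UNet
text_encoder = torch.nn.Linear(4, 4)   # stand-in for the text encoder
train_text_encoder_epochs = 10         # assumed hyperparameter
max_grad_norm = 1.0                    # assumed hyperparameter

def on_before_optimize(lr: float, epoch: int):
    # Clip only on steps where gradients are synchronized (i.e. not mid
    # gradient accumulation), chaining whichever modules are still training.
    if accelerator.sync_gradients:
        params_to_clip = [unet.parameters()]
        if epoch < train_text_encoder_epochs:
            params_to_clip.append(text_encoder.parameters())
        accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)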
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index ba78b98..e922954 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -117,14 +117,15 @@ def textual_inversion_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_after_optimize(lr: float):
+    @torch.no_grad()
+    def on_before_optimize(lr: float, epoch: int):
         if use_emb_decay:
-            with torch.no_grad():
-                text_encoder.text_model.embeddings.normalize(
-                    emb_decay_target,
-                    min(1.0, emb_decay * lr)
-                )
+            text_encoder.text_model.embeddings.normalize(
+                emb_decay_target,
+                min(1.0, emb_decay * lr)
+            )
 
+    def on_after_optimize(lr: float):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
@@ -154,6 +155,7 @@ def textual_inversion_strategy_callbacks(
         on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
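In the textual-inversion strategy the embedding decay moves from on_after_optimize into a new on_before_optimize callback, wrapped with @torch.no_grad() as a decorator instead of an inner context manager, and the new callback is registered alongside the existing ones. A rough, self-contained sketch of the decorator form, with a toy stand-in for embeddings.normalize (the decay math below is an assumption for illustration, not taken from the source):

# Sketch: torch.no_grad as a decorator replaces the inner `with` block, so
# the whole callback runs outside autograd. `weights` and the decay formula
# are illustrative stand-ins for embeddings.normalize(...).
import torch

weights = torch.nn.Parameter(torch.randn(8, 4))  # stand-in embedding table

@torch.no_grad()
def decay_toward(target_norm: float, strength: float):
    # Pull each row's norm toward target_norm by a factor of `strength`,
    # in place and without recording the update in the autograd graph.
    norms = weights.norm(dim=-1, keepdim=True)
    weights.mul_(1.0 + strength * (target_norm / norms - 1.0))

decay_toward(0.4, min(1.0, 100.0 * 1e-4))  # min(1.0, emb_decay * lr), as in the diff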