diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 6 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 6 | ||||
| -rw-r--r-- | training/strategy/lora.py | 9 | ||||
| -rw-r--r-- | training/strategy/ti.py | 6 |
4 files changed, 8 insertions, 19 deletions
diff --git a/training/functional.py b/training/functional.py index c30d1c0..4d83df1 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -34,7 +34,6 @@ def const(result=None): | |||
| 34 | 34 | ||
| 35 | @dataclass | 35 | @dataclass |
| 36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
| 37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | ||
| 38 | on_log: Callable[[], dict[str, Any]] = const({}) | 37 | on_log: Callable[[], dict[str, Any]] = const({}) |
| 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| 40 | on_before_optimize: Callable[[float, int], Any] = const() | 39 | on_before_optimize: Callable[[float, int], Any] = const() |
| @@ -461,7 +460,6 @@ def train_loop( | |||
| 461 | ) | 460 | ) |
| 462 | global_progress_bar.set_description("Total progress") | 461 | global_progress_bar.set_description("Total progress") |
| 463 | 462 | ||
| 464 | model = callbacks.on_accum_model() | ||
| 465 | on_log = callbacks.on_log | 463 | on_log = callbacks.on_log |
| 466 | on_train = callbacks.on_train | 464 | on_train = callbacks.on_train |
| 467 | on_before_optimize = callbacks.on_before_optimize | 465 | on_before_optimize = callbacks.on_before_optimize |
| @@ -498,8 +496,6 @@ def train_loop( | |||
| 498 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 496 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
| 499 | local_progress_bar.reset() | 497 | local_progress_bar.reset() |
| 500 | 498 | ||
| 501 | model.train() | ||
| 502 | |||
| 503 | with on_train(epoch): | 499 | with on_train(epoch): |
| 504 | for step, batch in enumerate(train_dataloader): | 500 | for step, batch in enumerate(train_dataloader): |
| 505 | loss, acc, bsz = loss_step(step, batch, cache) | 501 | loss, acc, bsz = loss_step(step, batch, cache) |
| @@ -560,8 +556,6 @@ def train_loop( | |||
| 560 | on_after_epoch() | 556 | on_after_epoch() |
| 561 | 557 | ||
| 562 | if val_dataloader is not None: | 558 | if val_dataloader is not None: |
| 563 | model.eval() | ||
| 564 | |||
| 565 | cur_loss_val = AverageMeter() | 559 | cur_loss_val = AverageMeter() |
| 566 | cur_acc_val = AverageMeter() | 560 | cur_acc_val = AverageMeter() |
| 567 | 561 | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 9808027..0286673 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -84,11 +84,9 @@ def dreambooth_strategy_callbacks( | |||
| 84 | else: | 84 | else: |
| 85 | return nullcontext() | 85 | return nullcontext() |
| 86 | 86 | ||
| 87 | def on_accum_model(): | ||
| 88 | return unet | ||
| 89 | |||
| 90 | @contextmanager | 87 | @contextmanager |
| 91 | def on_train(epoch: int): | 88 | def on_train(epoch: int): |
| 89 | unet.train() | ||
| 92 | tokenizer.train() | 90 | tokenizer.train() |
| 93 | 91 | ||
| 94 | if epoch < train_text_encoder_epochs: | 92 | if epoch < train_text_encoder_epochs: |
| @@ -101,6 +99,7 @@ def dreambooth_strategy_callbacks( | |||
| 101 | 99 | ||
| 102 | @contextmanager | 100 | @contextmanager |
| 103 | def on_eval(): | 101 | def on_eval(): |
| 102 | unet.eval() | ||
| 104 | tokenizer.eval() | 103 | tokenizer.eval() |
| 105 | text_encoder.eval() | 104 | text_encoder.eval() |
| 106 | 105 | ||
| @@ -174,7 +173,6 @@ def dreambooth_strategy_callbacks( | |||
| 174 | torch.cuda.empty_cache() | 173 | torch.cuda.empty_cache() |
| 175 | 174 | ||
| 176 | return TrainingCallbacks( | 175 | return TrainingCallbacks( |
| 177 | on_accum_model=on_accum_model, | ||
| 178 | on_train=on_train, | 176 | on_train=on_train, |
| 179 | on_eval=on_eval, | 177 | on_eval=on_eval, |
| 180 | on_before_optimize=on_before_optimize, | 178 | on_before_optimize=on_before_optimize, |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 6730dc9..80ffa9c 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -64,17 +64,17 @@ def lora_strategy_callbacks( | |||
| 64 | image_size=sample_image_size, | 64 | image_size=sample_image_size, |
| 65 | ) | 65 | ) |
| 66 | 66 | ||
| 67 | def on_accum_model(): | ||
| 68 | return unet | ||
| 69 | |||
| 70 | @contextmanager | 67 | @contextmanager |
| 71 | def on_train(epoch: int): | 68 | def on_train(epoch: int): |
| 72 | tokenizer.train() | 69 | unet.train() |
| 73 | text_encoder.train() | 70 | text_encoder.train() |
| 71 | tokenizer.train() | ||
| 74 | yield | 72 | yield |
| 75 | 73 | ||
| 76 | @contextmanager | 74 | @contextmanager |
| 77 | def on_eval(): | 75 | def on_eval(): |
| 76 | unet.eval() | ||
| 77 | text_encoder.eval() | ||
| 78 | tokenizer.eval() | 78 | tokenizer.eval() |
| 79 | yield | 79 | yield |
| 80 | 80 | ||
| @@ -152,7 +152,6 @@ def lora_strategy_callbacks( | |||
| 152 | torch.cuda.empty_cache() | 152 | torch.cuda.empty_cache() |
| 153 | 153 | ||
| 154 | return TrainingCallbacks( | 154 | return TrainingCallbacks( |
| 155 | on_accum_model=on_accum_model, | ||
| 156 | on_train=on_train, | 155 | on_train=on_train, |
| 157 | on_eval=on_eval, | 156 | on_eval=on_eval, |
| 158 | on_before_optimize=on_before_optimize, | 157 | on_before_optimize=on_before_optimize, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 55e9934..6a637c3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks( | |||
| 89 | else: | 89 | else: |
| 90 | return nullcontext() | 90 | return nullcontext() |
| 91 | 91 | ||
| 92 | def on_accum_model(): | ||
| 93 | return text_encoder.text_model.embeddings.token_override_embedding.params | ||
| 94 | |||
| 95 | @contextmanager | 92 | @contextmanager |
| 96 | def on_train(epoch: int): | 93 | def on_train(epoch: int): |
| 94 | text_encoder.text_model.embeddings.token_override_embedding.params.train() | ||
| 97 | tokenizer.train() | 95 | tokenizer.train() |
| 98 | yield | 96 | yield |
| 99 | 97 | ||
| 100 | @contextmanager | 98 | @contextmanager |
| 101 | def on_eval(): | 99 | def on_eval(): |
| 100 | text_encoder.text_model.embeddings.token_override_embedding.params.eval() | ||
| 102 | tokenizer.eval() | 101 | tokenizer.eval() |
| 103 | 102 | ||
| 104 | with ema_context(): | 103 | with ema_context(): |
| @@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks( | |||
| 166 | torch.cuda.empty_cache() | 165 | torch.cuda.empty_cache() |
| 167 | 166 | ||
| 168 | return TrainingCallbacks( | 167 | return TrainingCallbacks( |
| 169 | on_accum_model=on_accum_model, | ||
| 170 | on_train=on_train, | 168 | on_train=on_train, |
| 171 | on_eval=on_eval, | 169 | on_eval=on_eval, |
| 172 | on_before_optimize=on_before_optimize, | 170 | on_before_optimize=on_before_optimize, |
