diff options
author | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
commit | 21d70916f66e74a87c631a06b70774954b085b48 (patch) | |
tree | d1b443b9270f45ae6936f3acb565f767c7c65b1f /training/strategy/lora.py | |
parent | Run PTI only if placeholder tokens arg isn't empty (diff) | |
download | textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.tar.gz textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.tar.bz2 textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.zip |
Fix
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r-- | training/strategy/lora.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 6730dc9..80ffa9c 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -64,17 +64,17 @@ def lora_strategy_callbacks( | |||
64 | image_size=sample_image_size, | 64 | image_size=sample_image_size, |
65 | ) | 65 | ) |
66 | 66 | ||
67 | def on_accum_model(): | ||
68 | return unet | ||
69 | |||
70 | @contextmanager | 67 | @contextmanager |
71 | def on_train(epoch: int): | 68 | def on_train(epoch: int): |
72 | tokenizer.train() | 69 | unet.train() |
73 | text_encoder.train() | 70 | text_encoder.train() |
71 | tokenizer.train() | ||
74 | yield | 72 | yield |
75 | 73 | ||
76 | @contextmanager | 74 | @contextmanager |
77 | def on_eval(): | 75 | def on_eval(): |
76 | unet.eval() | ||
77 | text_encoder.eval() | ||
78 | tokenizer.eval() | 78 | tokenizer.eval() |
79 | yield | 79 | yield |
80 | 80 | ||
@@ -152,7 +152,6 @@ def lora_strategy_callbacks( | |||
152 | torch.cuda.empty_cache() | 152 | torch.cuda.empty_cache() |
153 | 153 | ||
154 | return TrainingCallbacks( | 154 | return TrainingCallbacks( |
155 | on_accum_model=on_accum_model, | ||
156 | on_train=on_train, | 155 | on_train=on_train, |
157 | on_eval=on_eval, | 156 | on_eval=on_eval, |
158 | on_before_optimize=on_before_optimize, | 157 | on_before_optimize=on_before_optimize, |