author     Volpeon <git@volpeon.ink>  2023-06-25 08:40:05 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-25 08:40:05 +0200
commit     4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132 (patch)
tree       c6ab59d6726b818638fe90a3ea8bb8403a8b0c30 /training
parent     Update (diff)
Update
Diffstat (limited to 'training')
-rw-r--r--  training/functional.py           3
-rw-r--r--  training/strategy/dreambooth.py  3
-rw-r--r--  training/strategy/ti.py          7
3 files changed, 10 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index 8917eb7..b60afe3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -786,7 +786,4 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
-    text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
-    unet.forward = MethodType(unet.forward, unet)
-
     accelerator.free_memory()
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 3d1abf7..7e67589 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+        unet_.forward = MethodType(unet_.forward, unet_)
+
         text_encoder_.text_model.embeddings.persist(False)
 
         with ema_context():
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 7373982..f37dfb4 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
+        if postfix == "end":
+            text_encoder_ = accelerator.unwrap_model(
+                text_encoder, keep_fp32_wrapper=False
+            )
+            text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
         with ema_context():
             for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
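
Note on the change: the hunks above move the `forward` re-binding out of train() in
training/functional.py and into the strategy callbacks. dreambooth.py re-binds both
text_encoder_ and unet_ when a checkpoint is saved, while ti.py only re-binds
text_encoder_ for the final ("end") checkpoint. The sketch below is a minimal,
standalone illustration of the types.MethodType pattern used in those lines; it is
not the project's code, and TinyModel / patched_forward are hypothetical names.
MethodType(func, obj) turns a plain function into a method bound to obj, so obj is
passed as self on every call, which is what an assignment like
`unet_.forward = MethodType(unet_.forward, unet_)` relies on after unwrapping.

from types import MethodType


class TinyModel:
    def __init__(self, scale: float) -> None:
        self.scale = scale

    def forward(self, x: float) -> float:
        return self.scale * x


def patched_forward(self, x: float) -> float:
    # Stand-in for a wrapper-injected forward (hypothetical).
    return self.scale * x + 1.0


model = TinyModel(2.0)

# Re-bind the replacement function to the instance, mirroring the pattern in
# the diff: the resulting bound method receives `model` as `self` automatically.
model.forward = MethodType(patched_forward, model)

print(model.forward(3.0))  # prints 7.0 (2.0 * 3.0 + 1.0)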