From 4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Jun 2023 08:40:05 +0200 Subject: Update --- training/strategy/ti.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'training/strategy/ti.py') diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 7373982..f37dfb4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -1,4 +1,5 @@ from typing import Optional +from types import MethodType from functools import partial from contextlib import contextmanager, nullcontext from pathlib import Path @@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks( def on_checkpoint(step, postfix): print(f"Saving checkpoint for step {step}...") + if postfix == "end": + text_encoder_ = accelerator.unwrap_model( + text_encoder, keep_fp32_wrapper=False + ) + text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) + with ema_context(): for token, ids in zip(placeholder_tokens, placeholder_token_ids): text_encoder.text_model.embeddings.save_embed( -- cgit v1.2.3-54-g00ecf