From 4f6ecc2fd38dc894ac7fcb4f130d3ab8af49d132 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Jun 2023 08:40:05 +0200 Subject: Update --- training/functional.py | 3 --- training/strategy/dreambooth.py | 3 +++ training/strategy/ti.py | 7 +++++++ 3 files changed, 10 insertions(+), 3 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 8917eb7..b60afe3 100644 --- a/training/functional.py +++ b/training/functional.py @@ -786,7 +786,4 @@ def train( accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) accelerator.unwrap_model(unet, keep_fp32_wrapper=False) - text_encoder.forward = MethodType(text_encoder.forward, text_encoder) - unet.forward = MethodType(unet.forward, unet) - accelerator.free_memory() diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 3d1abf7..7e67589 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -154,6 +154,9 @@ def dreambooth_strategy_callbacks( unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) + text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) + unet_.forward = MethodType(unet_.forward, unet_) + text_encoder_.text_model.embeddings.persist(False) with ema_context(): diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 7373982..f37dfb4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -1,4 +1,5 @@ from typing import Optional +from types import MethodType from functools import partial from contextlib import contextmanager, nullcontext from pathlib import Path @@ -139,6 +140,12 @@ def textual_inversion_strategy_callbacks( def on_checkpoint(step, postfix): print(f"Saving checkpoint for step {step}...") + if postfix == "end": + text_encoder_ = accelerator.unwrap_model( + text_encoder, keep_fp32_wrapper=False + ) + text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) + with ema_context(): for token, ids in zip(placeholder_tokens, placeholder_token_ids): text_encoder.text_model.embeddings.save_embed( -- cgit v1.2.3-70-g09d2