From 842f26654bbe7dfd2f45df1fd2660d3f902af8cc Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 17 Feb 2023 14:53:25 +0100
Subject: Remove xformers, switch to Pytorch Nightly

---
 training/functional.py          | 3 +--
 training/strategy/dreambooth.py | 8 ++++----
 training/strategy/lora.py       | 2 +-
 training/strategy/ti.py         | 4 ++--
 4 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 78a2b10..41794ea 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, UniPCMultistepScheduler
 from tqdm.auto import tqdm
 from PIL import Image
 
@@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 
 from training.util import AverageMeter
 
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 8aaed3a..d697554 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -144,8 +144,8 @@ def dreambooth_strategy_callbacks(
 
         print("Saving model...")
 
-        unet_ = accelerator.unwrap_model(unet)
-        text_encoder_ = accelerator.unwrap_model(text_encoder)
+        unet_ = accelerator.unwrap_model(unet, False)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
 
         with ema_context():
             pipeline = VlpnStableDiffusion(
@@ -167,8 +167,8 @@ def dreambooth_strategy_callbacks(
     @torch.no_grad()
     def on_sample(step):
         with ema_context():
-            unet_ = accelerator.unwrap_model(unet)
-            text_encoder_ = accelerator.unwrap_model(text_encoder)
+            unet_ = accelerator.unwrap_model(unet, False)
+            text_encoder_ = accelerator.unwrap_model(text_encoder, False)
 
             orig_unet_dtype = unet_.dtype
             orig_text_encoder_dtype = text_encoder_.dtype
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 4dd1100..ccec215 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -90,7 +90,7 @@ def lora_strategy_callbacks(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        unet_ = accelerator.unwrap_model(unet)
+        unet_ = accelerator.unwrap_model(unet, False)
         unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
 
         del unet_
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 0de3cb0..66d3129 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -144,8 +144,8 @@ def textual_inversion_strategy_callbacks(
    @torch.no_grad()
    def on_sample(step):
        with ema_context():
-            unet_ = accelerator.unwrap_model(unet)
-            text_encoder_ = accelerator.unwrap_model(text_encoder)
+            unet_ = accelerator.unwrap_model(unet, False)
+            text_encoder_ = accelerator.unwrap_model(text_encoder, False)
 
             orig_unet_dtype = unet_.dtype
             orig_text_encoder_dtype = text_encoder_.dtype
-- 
cgit v1.2.3-54-g00ecf
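
Below, outside the patch itself, is a minimal sketch of the two API changes the diff relies on: importing UniPCMultistepScheduler directly from diffusers instead of the vendored schedulers/ copy, and passing keep_fp32_wrapper=False to accelerator.unwrap_model before saving. It assumes diffusers >= 0.13 and a recent accelerate release; the toy Linear model is a stand-in for the real UNet and text encoder.

# Minimal sketch (not part of the patch), assuming diffusers >= 0.13 and a recent accelerate.
import torch
from accelerate import Accelerator
from diffusers import UniPCMultistepScheduler  # now shipped by diffusers, no local schedulers/ module needed

# The UniPC scheduler can be constructed directly with its default config values.
scheduler = UniPCMultistepScheduler(num_train_timesteps=1000)

# unwrap_model(..., keep_fp32_wrapper=False) strips the mixed-precision forward wrapper
# that accelerate adds in prepare(), mirroring the `accelerator.unwrap_model(unet, False)`
# calls in the diff; the hypothetical Linear model stands in for unet / text_encoder.
accelerator = Accelerator()
model = accelerator.prepare(torch.nn.Linear(4, 4))
model_ = accelerator.unwrap_model(model, keep_fp32_wrapper=False)
print(type(model_))  # plain torch.nn.Linear, safe to hand to a pipeline or save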