From 1eef9a946161fd06b0e72ec804c68f4f0e74b380 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 9 Oct 2022 12:42:21 +0200 Subject: Update --- textual_inversion.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 09871d4..e641cab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -16,14 +16,14 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel -from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from schedulers.scheduling_euler_a import EulerAScheduler +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule logger = get_logger(__name__) @@ -388,7 +388,6 @@ class Checkpointer: tokenizer=self.tokenizer, scheduler=scheduler, ).to(self.accelerator.device) - pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) train_data = self.datamodule.train_dataloader() @@ -518,8 +517,8 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - slice_size = unet.config.attention_head_dim // 2 - unet.set_attention_slice(slice_size) + # slice_size = unet.config.attention_head_dim // 2 + # unet.set_attention_slice(slice_size) # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) @@ -639,7 +638,6 @@ def main(): tokenizer=tokenizer, scheduler=scheduler, ).to(accelerator.device) - pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) with torch.inference_mode(): -- cgit v1.2.3-54-g00ecf