From f23fd5184b8ba4ec04506495f4a61726e50756f7 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 17:38:44 +0200 Subject: Small perf improvements --- textual_inversion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 00d460f..5fc2338 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -14,7 +14,7 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler from PIL import Image @@ -30,6 +30,9 @@ from data.textual_inversion.csv import CSVDataModule logger = get_logger(__name__) +torch.backends.cuda.matmul.allow_tf32 = True + + def parse_args(): parser = argparse.ArgumentParser( description="Simple example of a training script." @@ -370,7 +373,6 @@ class Checkpointer: unet=self.unet, tokenizer=self.tokenizer, scheduler=scheduler, - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ).to(self.accelerator.device) pipeline.enable_attention_slicing() -- cgit v1.2.3-54-g00ecf