diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 00d460f..5fc2338 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -14,7 +14,7 @@ import torch.utils.checkpoint | |||
14 | from accelerate import Accelerator | 14 | from accelerate import Accelerator |
15 | from accelerate.logging import get_logger | 15 | from accelerate.logging import get_logger |
16 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
17 | from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel | 17 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
18 | from schedulers.scheduling_euler_a import EulerAScheduler | 18 | from schedulers.scheduling_euler_a import EulerAScheduler |
19 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler |
20 | from PIL import Image | 20 | from PIL import Image |
@@ -30,6 +30,9 @@ from data.textual_inversion.csv import CSVDataModule | |||
30 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
31 | 31 | ||
32 | 32 | ||
33 | torch.backends.cuda.matmul.allow_tf32 = True | ||
34 | |||
35 | |||
33 | def parse_args(): | 36 | def parse_args(): |
34 | parser = argparse.ArgumentParser( | 37 | parser = argparse.ArgumentParser( |
35 | description="Simple example of a training script." | 38 | description="Simple example of a training script." |
@@ -370,7 +373,6 @@ class Checkpointer: | |||
370 | unet=self.unet, | 373 | unet=self.unet, |
371 | tokenizer=self.tokenizer, | 374 | tokenizer=self.tokenizer, |
372 | scheduler=scheduler, | 375 | scheduler=scheduler, |
373 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), | ||
374 | ).to(self.accelerator.device) | 376 | ).to(self.accelerator.device) |
375 | pipeline.enable_attention_slicing() | 377 | pipeline.enable_attention_slicing() |
376 | 378 | ||