diff options
author | Volpeon <git@volpeon.ink> | 2022-10-09 12:42:21 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-09 12:42:21 +0200 |
commit | 1eef9a946161fd06b0e72ec804c68f4f0e74b380 (patch) | |
tree | b4a272b7240c25c0eef173dbfd193ba89a592929 /textual_inversion.py | |
parent | Update (diff) | |
download | textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.tar.gz textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.tar.bz2 textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.zip |
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 10 |
1 file changed, 4 insertions, 6 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 09871d4..e641cab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -16,14 +16,14 @@ from accelerate import Accelerator | |||
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
19 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
20 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler |
21 | from PIL import Image | 20 | from PIL import Image |
22 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
23 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
24 | from slugify import slugify | 23 | from slugify import slugify |
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
26 | 24 | ||
25 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
28 | 28 | ||
29 | logger = get_logger(__name__) | 29 | logger = get_logger(__name__) |
@@ -388,7 +388,6 @@ class Checkpointer: | |||
388 | tokenizer=self.tokenizer, | 388 | tokenizer=self.tokenizer, |
389 | scheduler=scheduler, | 389 | scheduler=scheduler, |
390 | ).to(self.accelerator.device) | 390 | ).to(self.accelerator.device) |
391 | pipeline.enable_attention_slicing() | ||
392 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 391 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
393 | 392 | ||
394 | train_data = self.datamodule.train_dataloader() | 393 | train_data = self.datamodule.train_dataloader() |
@@ -518,8 +517,8 @@ def main(): | |||
518 | if args.gradient_checkpointing: | 517 | if args.gradient_checkpointing: |
519 | unet.enable_gradient_checkpointing() | 518 | unet.enable_gradient_checkpointing() |
520 | 519 | ||
521 | slice_size = unet.config.attention_head_dim // 2 | 520 | # slice_size = unet.config.attention_head_dim // 2 |
522 | unet.set_attention_slice(slice_size) | 521 | # unet.set_attention_slice(slice_size) |
523 | 522 | ||
524 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 523 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
525 | text_encoder.resize_token_embeddings(len(tokenizer)) | 524 | text_encoder.resize_token_embeddings(len(tokenizer)) |
@@ -639,7 +638,6 @@ def main(): | |||
639 | tokenizer=tokenizer, | 638 | tokenizer=tokenizer, |
640 | scheduler=scheduler, | 639 | scheduler=scheduler, |
641 | ).to(accelerator.device) | 640 | ).to(accelerator.device) |
642 | pipeline.enable_attention_slicing() | ||
643 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 641 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
644 | 642 | ||
645 | with torch.inference_mode(): | 643 | with torch.inference_mode(): |