diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-09 12:42:21 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-09 12:42:21 +0200 |
| commit | 1eef9a946161fd06b0e72ec804c68f4f0e74b380 (patch) | |
| tree | b4a272b7240c25c0eef173dbfd193ba89a592929 /textual_inversion.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.tar.gz textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.tar.bz2 textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.zip | |
Update
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 09871d4..e641cab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -16,14 +16,14 @@ from accelerate import Accelerator | |||
| 16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
| 19 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 20 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler |
| 21 | from PIL import Image | 20 | from PIL import Image |
| 22 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
| 23 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
| 24 | from slugify import slugify | 23 | from slugify import slugify |
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 26 | 24 | ||
| 25 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
| 28 | 28 | ||
| 29 | logger = get_logger(__name__) | 29 | logger = get_logger(__name__) |
| @@ -388,7 +388,6 @@ class Checkpointer: | |||
| 388 | tokenizer=self.tokenizer, | 388 | tokenizer=self.tokenizer, |
| 389 | scheduler=scheduler, | 389 | scheduler=scheduler, |
| 390 | ).to(self.accelerator.device) | 390 | ).to(self.accelerator.device) |
| 391 | pipeline.enable_attention_slicing() | ||
| 392 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 391 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 393 | 392 | ||
| 394 | train_data = self.datamodule.train_dataloader() | 393 | train_data = self.datamodule.train_dataloader() |
| @@ -518,8 +517,8 @@ def main(): | |||
| 518 | if args.gradient_checkpointing: | 517 | if args.gradient_checkpointing: |
| 519 | unet.enable_gradient_checkpointing() | 518 | unet.enable_gradient_checkpointing() |
| 520 | 519 | ||
| 521 | slice_size = unet.config.attention_head_dim // 2 | 520 | # slice_size = unet.config.attention_head_dim // 2 |
| 522 | unet.set_attention_slice(slice_size) | 521 | # unet.set_attention_slice(slice_size) |
| 523 | 522 | ||
| 524 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 523 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
| 525 | text_encoder.resize_token_embeddings(len(tokenizer)) | 524 | text_encoder.resize_token_embeddings(len(tokenizer)) |
| @@ -639,7 +638,6 @@ def main(): | |||
| 639 | tokenizer=tokenizer, | 638 | tokenizer=tokenizer, |
| 640 | scheduler=scheduler, | 639 | scheduler=scheduler, |
| 641 | ).to(accelerator.device) | 640 | ).to(accelerator.device) |
| 642 | pipeline.enable_attention_slicing() | ||
| 643 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 641 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 644 | 642 | ||
| 645 | with torch.inference_mode(): | 643 | with torch.inference_mode(): |
