summary | refs | log | tree | commit | diff | stats
path: root/textual_inversion.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2022-10-09 12:42:21 +0200
committer: Volpeon <git@volpeon.ink> 2022-10-09 12:42:21 +0200
commit: 1eef9a946161fd06b0e72ec804c68f4f0e74b380 (patch)
tree: b4a272b7240c25c0eef173dbfd193ba89a592929 /textual_inversion.py
parent: Update (diff)
download: textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.tar.gz
download: textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.tar.bz2
download: textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.zip
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- textual_inversion.py | 10 lines
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 09871d4..e641cab 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -16,14 +16,14 @@ from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 18from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
19from schedulers.scheduling_euler_a import EulerAScheduler
20from diffusers.optimization import get_scheduler 19from diffusers.optimization import get_scheduler
21from PIL import Image 20from PIL import Image
22from tqdm.auto import tqdm 21from tqdm.auto import tqdm
23from transformers import CLIPTextModel, CLIPTokenizer 22from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify 23from slugify import slugify
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26 24
25from schedulers.scheduling_euler_a import EulerAScheduler
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule 27from data.csv import CSVDataModule
28 28
29logger = get_logger(__name__) 29logger = get_logger(__name__)
@@ -388,7 +388,6 @@ class Checkpointer:
388 tokenizer=self.tokenizer, 388 tokenizer=self.tokenizer,
389 scheduler=scheduler, 389 scheduler=scheduler,
390 ).to(self.accelerator.device) 390 ).to(self.accelerator.device)
391 pipeline.enable_attention_slicing()
392 pipeline.set_progress_bar_config(dynamic_ncols=True) 391 pipeline.set_progress_bar_config(dynamic_ncols=True)
393 392
394 train_data = self.datamodule.train_dataloader() 393 train_data = self.datamodule.train_dataloader()
@@ -518,8 +517,8 @@ def main():
518 if args.gradient_checkpointing: 517 if args.gradient_checkpointing:
519 unet.enable_gradient_checkpointing() 518 unet.enable_gradient_checkpointing()
520 519
521 slice_size = unet.config.attention_head_dim // 2 520 # slice_size = unet.config.attention_head_dim // 2
522 unet.set_attention_slice(slice_size) 521 # unet.set_attention_slice(slice_size)
523 522
524 # Resize the token embeddings as we are adding new special tokens to the tokenizer 523 # Resize the token embeddings as we are adding new special tokens to the tokenizer
525 text_encoder.resize_token_embeddings(len(tokenizer)) 524 text_encoder.resize_token_embeddings(len(tokenizer))
@@ -639,7 +638,6 @@ def main():
639 tokenizer=tokenizer, 638 tokenizer=tokenizer,
640 scheduler=scheduler, 639 scheduler=scheduler,
641 ).to(accelerator.device) 640 ).to(accelerator.device)
642 pipeline.enable_attention_slicing()
643 pipeline.set_progress_bar_config(dynamic_ncols=True) 641 pipeline.set_progress_bar_config(dynamic_ncols=True)
644 642
645 with torch.inference_mode(): 643 with torch.inference_mode():