author     Volpeon <git@volpeon.ink>  2022-10-09 12:42:21 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-09 12:42:21 +0200
commit     1eef9a946161fd06b0e72ec804c68f4f0e74b380 (patch)
tree       b4a272b7240c25c0eef173dbfd193ba89a592929 /dreambooth.py
parent     Update (diff)
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 7b61c45..48fc7f2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,14 +15,14 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
-from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
+from schedulers.scheduling_euler_a import EulerAScheduler
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 
 logger = get_logger(__name__)
@@ -334,7 +334,6 @@ class Checkpointer:
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
             ),
         )
-        pipeline.enable_attention_slicing()
         pipeline.save_pretrained(self.output_dir.joinpath("model"))
 
         del unwrapped
@@ -359,7 +358,6 @@ class Checkpointer:
             tokenizer=self.tokenizer,
             scheduler=scheduler,
         ).to(self.accelerator.device)
-        pipeline.enable_attention_slicing()
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
@@ -561,7 +559,6 @@ def main():
         tokenizer=tokenizer,
         scheduler=scheduler,
     ).to(accelerator.device)
-    pipeline.enable_attention_slicing()
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
     with torch.inference_mode():
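
For reference, enable_attention_slicing() is a standard diffusers pipeline method that splits the attention computation into smaller slices, trading a little speed for lower peak VRAM; this commit drops that call at the three places the pipeline is constructed. A minimal sketch of the API, assuming a stock StableDiffusionPipeline stands in for the repository's VlpnStableDiffusion and using an illustrative checkpoint id:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint id; any Stable Diffusion checkpoint behaves the same here.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
).to("cuda")

# Compute attention in slices: lower peak VRAM at a small speed cost.
pipe.enable_attention_slicing()

# Restore the default full-attention computation.
pipe.disable_attention_slicing()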