diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-09 12:42:21 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-09 12:42:21 +0200 |
| commit | 1eef9a946161fd06b0e72ec804c68f4f0e74b380 (patch) | |
| tree | b4a272b7240c25c0eef173dbfd193ba89a592929 /dreambooth.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.tar.gz textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.tar.bz2 textual-inversion-diff-1eef9a946161fd06b0e72ec804c68f4f0e74b380.zip | |
Update
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py index 7b61c45..48fc7f2 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -15,14 +15,14 @@ from accelerate import Accelerator | |||
| 15 | from accelerate.logging import get_logger | 15 | from accelerate.logging import get_logger |
| 16 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
| 17 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel | 17 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel |
| 18 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 19 | from diffusers.optimization import get_scheduler | 18 | from diffusers.optimization import get_scheduler |
| 20 | from PIL import Image | 19 | from PIL import Image |
| 21 | from tqdm.auto import tqdm | 20 | from tqdm.auto import tqdm |
| 22 | from transformers import CLIPTextModel, CLIPTokenizer | 21 | from transformers import CLIPTextModel, CLIPTokenizer |
| 23 | from slugify import slugify | 22 | from slugify import slugify |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 25 | 23 | ||
| 24 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 26 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
| 27 | 27 | ||
| 28 | logger = get_logger(__name__) | 28 | logger = get_logger(__name__) |
| @@ -334,7 +334,6 @@ class Checkpointer: | |||
| 334 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 334 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True |
| 335 | ), | 335 | ), |
| 336 | ) | 336 | ) |
| 337 | pipeline.enable_attention_slicing() | ||
| 338 | pipeline.save_pretrained(self.output_dir.joinpath("model")) | 337 | pipeline.save_pretrained(self.output_dir.joinpath("model")) |
| 339 | 338 | ||
| 340 | del unwrapped | 339 | del unwrapped |
| @@ -359,7 +358,6 @@ class Checkpointer: | |||
| 359 | tokenizer=self.tokenizer, | 358 | tokenizer=self.tokenizer, |
| 360 | scheduler=scheduler, | 359 | scheduler=scheduler, |
| 361 | ).to(self.accelerator.device) | 360 | ).to(self.accelerator.device) |
| 362 | pipeline.enable_attention_slicing() | ||
| 363 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 361 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 364 | 362 | ||
| 365 | train_data = self.datamodule.train_dataloader() | 363 | train_data = self.datamodule.train_dataloader() |
| @@ -561,7 +559,6 @@ def main(): | |||
| 561 | tokenizer=tokenizer, | 559 | tokenizer=tokenizer, |
| 562 | scheduler=scheduler, | 560 | scheduler=scheduler, |
| 563 | ).to(accelerator.device) | 561 | ).to(accelerator.device) |
| 564 | pipeline.enable_attention_slicing() | ||
| 565 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 562 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 566 | 563 | ||
| 567 | with torch.inference_mode(): | 564 | with torch.inference_mode(): |
