diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/dreambooth.py b/dreambooth.py index 7b61c45..48fc7f2 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -15,14 +15,14 @@ from accelerate import Accelerator | |||
15 | from accelerate.logging import get_logger | 15 | from accelerate.logging import get_logger |
16 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
17 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel | 17 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel |
18 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
19 | from diffusers.optimization import get_scheduler | 18 | from diffusers.optimization import get_scheduler |
20 | from PIL import Image | 19 | from PIL import Image |
21 | from tqdm.auto import tqdm | 20 | from tqdm.auto import tqdm |
22 | from transformers import CLIPTextModel, CLIPTokenizer | 21 | from transformers import CLIPTextModel, CLIPTokenizer |
23 | from slugify import slugify | 22 | from slugify import slugify |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
25 | 23 | ||
24 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
26 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
27 | 27 | ||
28 | logger = get_logger(__name__) | 28 | logger = get_logger(__name__) |
@@ -334,7 +334,6 @@ class Checkpointer: | |||
334 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 334 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True |
335 | ), | 335 | ), |
336 | ) | 336 | ) |
337 | pipeline.enable_attention_slicing() | ||
338 | pipeline.save_pretrained(self.output_dir.joinpath("model")) | 337 | pipeline.save_pretrained(self.output_dir.joinpath("model")) |
339 | 338 | ||
340 | del unwrapped | 339 | del unwrapped |
@@ -359,7 +358,6 @@ class Checkpointer: | |||
359 | tokenizer=self.tokenizer, | 358 | tokenizer=self.tokenizer, |
360 | scheduler=scheduler, | 359 | scheduler=scheduler, |
361 | ).to(self.accelerator.device) | 360 | ).to(self.accelerator.device) |
362 | pipeline.enable_attention_slicing() | ||
363 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 361 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
364 | 362 | ||
365 | train_data = self.datamodule.train_dataloader() | 363 | train_data = self.datamodule.train_dataloader() |
@@ -561,7 +559,6 @@ def main(): | |||
561 | tokenizer=tokenizer, | 559 | tokenizer=tokenizer, |
562 | scheduler=scheduler, | 560 | scheduler=scheduler, |
563 | ).to(accelerator.device) | 561 | ).to(accelerator.device) |
564 | pipeline.enable_attention_slicing() | ||
565 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 562 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
566 | 563 | ||
567 | with torch.inference_mode(): | 564 | with torch.inference_mode(): |