From 1eef9a946161fd06b0e72ec804c68f4f0e74b380 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 9 Oct 2022 12:42:21 +0200 Subject: Update --- dreambooth.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 7b61c45..48fc7f2 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -15,14 +15,14 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel -from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from schedulers.scheduling_euler_a import EulerAScheduler +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule logger = get_logger(__name__) @@ -334,7 +334,6 @@ class Checkpointer: beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True ), ) - pipeline.enable_attention_slicing() pipeline.save_pretrained(self.output_dir.joinpath("model")) del unwrapped @@ -359,7 +358,6 @@ class Checkpointer: tokenizer=self.tokenizer, scheduler=scheduler, ).to(self.accelerator.device) - pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) train_data = self.datamodule.train_dataloader() @@ -561,7 +559,6 @@ def main(): tokenizer=tokenizer, scheduler=scheduler, ).to(accelerator.device) - pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) with torch.inference_mode(): -- cgit v1.2.3-54-g00ecf