From 0f493e1ac8406de061861ed390f283e821180e79 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 3 Oct 2022 11:26:31 +0200
Subject: Use euler_a for samples in learning scripts; backported improvement
 from Dreambooth to Textual Inversion

---
 dreambooth.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

(limited to 'dreambooth.py')

diff --git a/dreambooth.py b/dreambooth.py
index 4d7366c..744d1bc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -14,12 +14,14 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 
 from data.dreambooth.csv import CSVDataModule
@@ -215,7 +217,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=80,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -377,15 +379,16 @@ class Checkpointer:
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        scheduler = EulerAScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=unwrapped,
             tokenizer=self.tokenizer,
-            scheduler=LMSDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            ),
-            safety_checker=NoCheck(),
+            scheduler=scheduler,
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
@@ -411,6 +414,8 @@ class Checkpointer:
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                     batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
 
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
+
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
@@ -420,10 +425,13 @@ class Checkpointer:
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
+                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                 all_samples += samples
+
+                del generator
                 del samples
 
         image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -444,6 +452,8 @@ class Checkpointer:
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                     batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
 
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
+
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
@@ -452,10 +462,13 @@ class Checkpointer:
                         height=self.sample_image_size,
                         width=self.sample_image_size,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
+                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                 all_samples += samples
+
+                del generator
                 del samples
 
         image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
@@ -465,6 +478,7 @@ class Checkpointer:
         del image_grid
         del unwrapped
+        del scheduler
         del pipeline
 
         if torch.cuda.is_available():
-- 
cgit v1.2.3-54-g00ecf
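
For context, the pieces this commit wires together can be exercised outside the training loop. The sketch below is illustrative, not the committed code: it assumes this repository's own schedulers.scheduling_euler_a and pipelines.stable_diffusion.vlpn_stable_diffusion modules, standard Stable Diffusion v1.4 weights, and an example model id, seed, and prompt chosen here; VlpnStableDiffusion's constructor and call signature are taken from the diff above rather than any published API.

    import torch
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

    # Repository-local modules (not part of the diffusers package).
    from schedulers.scheduling_euler_a import EulerAScheduler
    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

    # Illustrative checkpoint; the training script passes its own fine-tuned parts.
    model = "CompVis/stable-diffusion-v1-4"
    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model, subfolder="unet")

    # Same settings as the diff: Euler ancestral with SD's scaled-linear betas.
    scheduler = EulerAScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )

    # The custom pipeline replaces StableDiffusionPipeline and takes no
    # safety_checker argument, which is why NoCheck drops out of the call.
    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
    ).to("cuda")
    pipeline.enable_attention_slicing()

    # Seeding one generator per sample batch (seed + i in the diff) keeps the
    # sample grids reproducible from checkpoint to checkpoint.
    generator = torch.Generator(device="cuda").manual_seed(42)
    samples = pipeline(
        prompt=["a photo of a dog"],
        height=512,
        width=512,
        guidance_scale=7.5,
        eta=0.0,
        num_inference_steps=30,  # the new --sample_steps default
        generator=generator,
        output_type="pil",
    )["sample"]
    samples[0].save("sample.png")

Lowering --sample_steps from 80 to 30 fits the scheduler swap: the ancestral Euler sampler typically reaches usable sample quality in far fewer steps than the LMS scheduler it replaces, so checkpoint-time sampling gets cheaper without obviously worse previews.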