From 6720c99f7082dc855059ad4afd6b3cb45b62bc1f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 1 Oct 2022 16:53:19 +0200
Subject: Fix seed, better progress bar, fix euler_a for batch size > 1

---
 infer.py                                             | 10 +++++++---
 .../stable_diffusion/clip_guided_stable_diffusion.py |  4 ++--
 schedulers/scheduling_euler_a.py                     | 16 ++++++----------
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/infer.py b/infer.py
index 40720ea..d917239 100644
--- a/infer.py
+++ b/infer.py
@@ -91,7 +91,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--seed",
         type=int,
-        default=torch.random.seed(),
+        default=None,
     )
     parser.add_argument(
         "--config",
@@ -167,11 +167,15 @@ def generate(output_dir, pipeline, args):
     output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
     output_dir.mkdir(parents=True, exist_ok=True)
 
+    seed = args.seed or torch.random.seed()
+
     save_args(output_dir, args)
 
     with autocast("cuda"):
         for i in range(args.batch_num):
-            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
+            pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
+
+            generator = torch.Generator(device="cuda").manual_seed(seed + i)
             images = pipeline(
                 prompt=[args.prompt] * args.batch_size,
                 height=args.height,
@@ -183,7 +187,7 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
+            image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
 
 
 class CmdParse(cmd.Cmd):
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
index ddf7ce1..eff74b5 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -254,10 +254,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             noise_pred = None
             if isinstance(self.scheduler, EulerAScheduler):
                 sigma = t.reshape(1)
-                sigma_in = torch.cat([sigma] * 2)
+                sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
                 # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
                 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
-                                                text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas)
+                                                text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
                 # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
             else:
                 # predict the noise residual
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 29ebd07..9fbedaa 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -41,7 +41,6 @@ def get_sigmas(sigmas, n=None):
         return append_zero(sigmas.flip(0))
     t_max = len(sigmas) - 1  # = 999
     t = torch.linspace(t_max, 0, n, device=sigmas.device)
-    # t = torch.linspace(t_max, 0, n, device=sigmas.device)
     return append_zero(t_to_sigma(t, sigmas))
 
 # from k_samplers utils.py
@@ -55,14 +54,15 @@ def append_dims(x, target_dims):
     return x[(...,) + (None,) * dims_to_append]
 
 
-def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None):
+def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None):
     # x_in = torch.cat([x] * 2)#A# concat the latent
     # sigma_in = torch.cat([sigma] * 2) #A# concat sigma
     # cond_in = torch.cat([uncond, cond])
     # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
     # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2)
     # return uncond + (cond - uncond) * cond_scale
-    noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in)
+    noise_pred = DiscreteEpsDDPMDenoiserForward(
+        Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in)
     return noise_pred
 
 # from k_samplers sampling.py
@@ -82,9 +82,7 @@ def get_scalings(sigma):
 
 
 # DiscreteSchedule DS
-def DSsigma_to_t(sigma, quantize=None, DSsigmas=None):
-    # quantize = self.quantize if quantize is None else quantize
-    quantize = False
+def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
     dists = torch.abs(sigma - DSsigmas[:, None])
     if quantize:
         return torch.argmin(dists, dim=0).view(sigma.shape)
@@ -96,13 +94,11 @@ def DSsigma_to_t(sigma, quantize=None, DSsigmas=None):
     return t.view(sigma.shape)
 
 
-def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs):
+def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
     sigma = sigma.to(Unet.device)
     DSsigmas = DSsigmas.to(Unet.device)
     c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
-    # ??? what is eps?
-    # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs)
-    eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
+    eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
                encoder_hidden_states=kwargs['cond']).sample
     return input + eps * c_out
-- 
cgit v1.2.3-70-g09d2
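
For context, a minimal standalone sketch of the seeding and progress-bar pattern the infer.py hunks set up. It is not part of the patch: `run` and the inlined filename handling are hypothetical stand-ins for the real generate() in infer.py. The likely motivation for moving the seed out of argparse is that a `default=torch.random.seed()` is evaluated once, when the parser is built, so every command in the interactive shell reused the same seed; drawing it per run restores variety while keeping each run reproducible.

    import torch

    def run(pipeline, args):
        # One base seed per run: --seed if given, otherwise drawn fresh.
        # (Python's `or` treats an explicit --seed 0 as unset; the patch
        # accepts that trade-off.)
        seed = args.seed or torch.random.seed()

        # infer.py assumes CUDA; fall back to CPU so the sketch runs anywhere.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        for i in range(args.batch_num):
            # diffusers pipelines expose set_progress_bar_config(); its kwargs
            # are forwarded to tqdm, so desc= labels the bar for each batch.
            pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")

            # Offset the base seed per batch: batch i is reproducible from
            # (seed, i) alone, and filenames can embed the exact seed used.
            generator = torch.Generator(device=device).manual_seed(seed + i)

            images = pipeline(prompt=[args.prompt] * args.batch_size,
                              generator=generator).images
            for j, image in enumerate(images):
                image.save(f"{seed + i}_{j}.jpg")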
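The euler_a hunk in clip_guided_stable_diffusion.py is a shape fix: under classifier-free guidance the latent input stacks the unconditional and conditional halves (see the commented-out `torch.cat([x] * 2)` in CFGDenoiserForward), so its leading dimension is 2 * batch_size, and the per-sample sigma vector must be repeated to that length. The old hard-coded `* 2` was only correct for batch size 1. A shape-only sketch, with made-up tensor sizes:

    import torch

    batch_size = 3  # any value > 1 broke the old hard-coded `* 2`

    # Classifier-free guidance stacks unconditional + conditional latents.
    latent_model_input = torch.randn(2 * batch_size, 4, 64, 64)
    sigma = torch.tensor(14.6).reshape(1)  # one sigma for this timestep

    # Repeat sigma to the actual leading (batch) dimension, as in the patch.
    sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
    assert sigma_in.shape == (latent_model_input.shape[0],)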