diff options
| -rw-r--r-- | infer.py | 10 | ||||
| -rw-r--r-- | pipelines/stable_diffusion/clip_guided_stable_diffusion.py | 4 | ||||
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 16 |
3 files changed, 15 insertions, 15 deletions
| @@ -91,7 +91,7 @@ def create_cmd_parser(): | |||
| 91 | parser.add_argument( | 91 | parser.add_argument( |
| 92 | "--seed", | 92 | "--seed", |
| 93 | type=int, | 93 | type=int, |
| 94 | default=torch.random.seed(), | 94 | default=None, |
| 95 | ) | 95 | ) |
| 96 | parser.add_argument( | 96 | parser.add_argument( |
| 97 | "--config", | 97 | "--config", |
| @@ -167,11 +167,15 @@ def generate(output_dir, pipeline, args): | |||
| 167 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") | 167 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") |
| 168 | output_dir.mkdir(parents=True, exist_ok=True) | 168 | output_dir.mkdir(parents=True, exist_ok=True) |
| 169 | 169 | ||
| 170 | seed = args.seed or torch.random.seed() | ||
| 171 | |||
| 170 | save_args(output_dir, args) | 172 | save_args(output_dir, args) |
| 171 | 173 | ||
| 172 | with autocast("cuda"): | 174 | with autocast("cuda"): |
| 173 | for i in range(args.batch_num): | 175 | for i in range(args.batch_num): |
| 174 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | 176 | pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}") |
| 177 | |||
| 178 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | ||
| 175 | images = pipeline( | 179 | images = pipeline( |
| 176 | prompt=[args.prompt] * args.batch_size, | 180 | prompt=[args.prompt] * args.batch_size, |
| 177 | height=args.height, | 181 | height=args.height, |
| @@ -183,7 +187,7 @@ def generate(output_dir, pipeline, args): | |||
| 183 | ).images | 187 | ).images |
| 184 | 188 | ||
| 185 | for j, image in enumerate(images): | 189 | for j, image in enumerate(images): |
| 186 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) | 190 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) |
| 187 | 191 | ||
| 188 | 192 | ||
| 189 | class CmdParse(cmd.Cmd): | 193 | class CmdParse(cmd.Cmd): |
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py index ddf7ce1..eff74b5 100644 --- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py +++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py | |||
| @@ -254,10 +254,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): | |||
| 254 | noise_pred = None | 254 | noise_pred = None |
| 255 | if isinstance(self.scheduler, EulerAScheduler): | 255 | if isinstance(self.scheduler, EulerAScheduler): |
| 256 | sigma = t.reshape(1) | 256 | sigma = t.reshape(1) |
| 257 | sigma_in = torch.cat([sigma] * 2) | 257 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) |
| 258 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) | 258 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) |
| 259 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | 259 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, |
| 260 | text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas) | 260 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) |
| 261 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample | 261 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample |
| 262 | else: | 262 | else: |
| 263 | # predict the noise residual | 263 | # predict the noise residual |
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 29ebd07..9fbedaa 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
| @@ -41,7 +41,6 @@ def get_sigmas(sigmas, n=None): | |||
| 41 | return append_zero(sigmas.flip(0)) | 41 | return append_zero(sigmas.flip(0)) |
| 42 | t_max = len(sigmas) - 1 # = 999 | 42 | t_max = len(sigmas) - 1 # = 999 |
| 43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) | 43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) |
| 44 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
| 45 | return append_zero(t_to_sigma(t, sigmas)) | 44 | return append_zero(t_to_sigma(t, sigmas)) |
| 46 | 45 | ||
| 47 | # from k_samplers utils.py | 46 | # from k_samplers utils.py |
| @@ -55,14 +54,15 @@ def append_dims(x, target_dims): | |||
| 55 | return x[(...,) + (None,) * dims_to_append] | 54 | return x[(...,) + (None,) * dims_to_append] |
| 56 | 55 | ||
| 57 | 56 | ||
| 58 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): | 57 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None): |
| 59 | # x_in = torch.cat([x] * 2)#A# concat the latent | 58 | # x_in = torch.cat([x] * 2)#A# concat the latent |
| 60 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma | 59 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma |
| 61 | # cond_in = torch.cat([uncond, cond]) | 60 | # cond_in = torch.cat([uncond, cond]) |
| 62 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) | 61 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) |
| 63 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) | 62 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) |
| 64 | # return uncond + (cond - uncond) * cond_scale | 63 | # return uncond + (cond - uncond) * cond_scale |
| 65 | noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) | 64 | noise_pred = DiscreteEpsDDPMDenoiserForward( |
| 65 | Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in) | ||
| 66 | return noise_pred | 66 | return noise_pred |
| 67 | 67 | ||
| 68 | # from k_samplers sampling.py | 68 | # from k_samplers sampling.py |
| @@ -82,9 +82,7 @@ def get_scalings(sigma): | |||
| 82 | # DiscreteSchedule DS | 82 | # DiscreteSchedule DS |
| 83 | 83 | ||
| 84 | 84 | ||
| 85 | def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | 85 | def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): |
| 86 | # quantize = self.quantize if quantize is None else quantize | ||
| 87 | quantize = False | ||
| 88 | dists = torch.abs(sigma - DSsigmas[:, None]) | 86 | dists = torch.abs(sigma - DSsigmas[:, None]) |
| 89 | if quantize: | 87 | if quantize: |
| 90 | return torch.argmin(dists, dim=0).view(sigma.shape) | 88 | return torch.argmin(dists, dim=0).view(sigma.shape) |
| @@ -96,13 +94,11 @@ def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | |||
| 96 | return t.view(sigma.shape) | 94 | return t.view(sigma.shape) |
| 97 | 95 | ||
| 98 | 96 | ||
| 99 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): | 97 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): |
| 100 | sigma = sigma.to(Unet.device) | 98 | sigma = sigma.to(Unet.device) |
| 101 | DSsigmas = DSsigmas.to(Unet.device) | 99 | DSsigmas = DSsigmas.to(Unet.device) |
| 102 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | 100 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] |
| 103 | # ??? what is eps? | 101 | eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), |
| 104 | # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) | ||
| 105 | eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), | ||
| 106 | encoder_hidden_states=kwargs['cond']).sample | 102 | encoder_hidden_states=kwargs['cond']).sample |
| 107 | return input + eps * c_out | 103 | return input + eps * c_out |
| 108 | 104 | ||
