diff options
author | Volpeon <git@volpeon.ink> | 2022-10-01 16:53:19 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-01 16:53:19 +0200 |
commit | 6720c99f7082dc855059ad4afd6b3cb45b62bc1f (patch) | |
tree | d27f69880472df0cd6f63ea42bbf7a789ec5d0b7 | |
parent | Made inference script interactive (diff) | |
download | textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.tar.gz textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.tar.bz2 textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.zip |
Fix seed, better progress bar, fix euler_a for batch size > 1
-rw-r--r-- | infer.py | 10 | ||||
-rw-r--r-- | pipelines/stable_diffusion/clip_guided_stable_diffusion.py | 4 | ||||
-rw-r--r-- | schedulers/scheduling_euler_a.py | 16 |
3 files changed, 15 insertions, 15 deletions
@@ -91,7 +91,7 @@ def create_cmd_parser(): | |||
91 | parser.add_argument( | 91 | parser.add_argument( |
92 | "--seed", | 92 | "--seed", |
93 | type=int, | 93 | type=int, |
94 | default=torch.random.seed(), | 94 | default=None, |
95 | ) | 95 | ) |
96 | parser.add_argument( | 96 | parser.add_argument( |
97 | "--config", | 97 | "--config", |
@@ -167,11 +167,15 @@ def generate(output_dir, pipeline, args): | |||
167 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") | 167 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") |
168 | output_dir.mkdir(parents=True, exist_ok=True) | 168 | output_dir.mkdir(parents=True, exist_ok=True) |
169 | 169 | ||
170 | seed = args.seed or torch.random.seed() | ||
171 | |||
170 | save_args(output_dir, args) | 172 | save_args(output_dir, args) |
171 | 173 | ||
172 | with autocast("cuda"): | 174 | with autocast("cuda"): |
173 | for i in range(args.batch_num): | 175 | for i in range(args.batch_num): |
174 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | 176 | pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}") |
177 | |||
178 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | ||
175 | images = pipeline( | 179 | images = pipeline( |
176 | prompt=[args.prompt] * args.batch_size, | 180 | prompt=[args.prompt] * args.batch_size, |
177 | height=args.height, | 181 | height=args.height, |
@@ -183,7 +187,7 @@ def generate(output_dir, pipeline, args): | |||
183 | ).images | 187 | ).images |
184 | 188 | ||
185 | for j, image in enumerate(images): | 189 | for j, image in enumerate(images): |
186 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) | 190 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) |
187 | 191 | ||
188 | 192 | ||
189 | class CmdParse(cmd.Cmd): | 193 | class CmdParse(cmd.Cmd): |
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py index ddf7ce1..eff74b5 100644 --- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py +++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py | |||
@@ -254,10 +254,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): | |||
254 | noise_pred = None | 254 | noise_pred = None |
255 | if isinstance(self.scheduler, EulerAScheduler): | 255 | if isinstance(self.scheduler, EulerAScheduler): |
256 | sigma = t.reshape(1) | 256 | sigma = t.reshape(1) |
257 | sigma_in = torch.cat([sigma] * 2) | 257 | sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) |
258 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) | 258 | # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) |
259 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, | 259 | noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, |
260 | text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas) | 260 | text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) |
261 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample | 261 | # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample |
262 | else: | 262 | else: |
263 | # predict the noise residual | 263 | # predict the noise residual |
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 29ebd07..9fbedaa 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
@@ -41,7 +41,6 @@ def get_sigmas(sigmas, n=None): | |||
41 | return append_zero(sigmas.flip(0)) | 41 | return append_zero(sigmas.flip(0)) |
42 | t_max = len(sigmas) - 1 # = 999 | 42 | t_max = len(sigmas) - 1 # = 999 |
43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) | 43 | t = torch.linspace(t_max, 0, n, device=sigmas.device) |
44 | # t = torch.linspace(t_max, 0, n, device=sigmas.device) | ||
45 | return append_zero(t_to_sigma(t, sigmas)) | 44 | return append_zero(t_to_sigma(t, sigmas)) |
46 | 45 | ||
47 | # from k_samplers utils.py | 46 | # from k_samplers utils.py |
@@ -55,14 +54,15 @@ def append_dims(x, target_dims): | |||
55 | return x[(...,) + (None,) * dims_to_append] | 54 | return x[(...,) + (None,) * dims_to_append] |
56 | 55 | ||
57 | 56 | ||
58 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): | 57 | def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None): |
59 | # x_in = torch.cat([x] * 2)#A# concat the latent | 58 | # x_in = torch.cat([x] * 2)#A# concat the latent |
60 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma | 59 | # sigma_in = torch.cat([sigma] * 2) #A# concat sigma |
61 | # cond_in = torch.cat([uncond, cond]) | 60 | # cond_in = torch.cat([uncond, cond]) |
62 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) | 61 | # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) |
63 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) | 62 | # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) |
64 | # return uncond + (cond - uncond) * cond_scale | 63 | # return uncond + (cond - uncond) * cond_scale |
65 | noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) | 64 | noise_pred = DiscreteEpsDDPMDenoiserForward( |
65 | Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in) | ||
66 | return noise_pred | 66 | return noise_pred |
67 | 67 | ||
68 | # from k_samplers sampling.py | 68 | # from k_samplers sampling.py |
@@ -82,9 +82,7 @@ def get_scalings(sigma): | |||
82 | # DiscreteSchedule DS | 82 | # DiscreteSchedule DS |
83 | 83 | ||
84 | 84 | ||
85 | def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | 85 | def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): |
86 | # quantize = self.quantize if quantize is None else quantize | ||
87 | quantize = False | ||
88 | dists = torch.abs(sigma - DSsigmas[:, None]) | 86 | dists = torch.abs(sigma - DSsigmas[:, None]) |
89 | if quantize: | 87 | if quantize: |
90 | return torch.argmin(dists, dim=0).view(sigma.shape) | 88 | return torch.argmin(dists, dim=0).view(sigma.shape) |
@@ -96,13 +94,11 @@ def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): | |||
96 | return t.view(sigma.shape) | 94 | return t.view(sigma.shape) |
97 | 95 | ||
98 | 96 | ||
99 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): | 97 | def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): |
100 | sigma = sigma.to(Unet.device) | 98 | sigma = sigma.to(Unet.device) |
101 | DSsigmas = DSsigmas.to(Unet.device) | 99 | DSsigmas = DSsigmas.to(Unet.device) |
102 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] | 100 | c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] |
103 | # ??? what is eps? | 101 | eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), |
104 | # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs) | ||
105 | eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas), | ||
106 | encoder_hidden_states=kwargs['cond']).sample | 102 | encoder_hidden_states=kwargs['cond']).sample |
107 | return input + eps * c_out | 103 | return input + eps * c_out |
108 | 104 | ||