summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-01 16:53:19 +0200
committerVolpeon <git@volpeon.ink>2022-10-01 16:53:19 +0200
commit6720c99f7082dc855059ad4afd6b3cb45b62bc1f (patch)
treed27f69880472df0cd6f63ea42bbf7a789ec5d0b7
parentMade inference script interactive (diff)
downloadtextual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.tar.gz
textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.tar.bz2
textual-inversion-diff-6720c99f7082dc855059ad4afd6b3cb45b62bc1f.zip
Fix seed, better progress bar, fix euler_a for batch size > 1
-rw-r--r--infer.py10
-rw-r--r--pipelines/stable_diffusion/clip_guided_stable_diffusion.py4
-rw-r--r--schedulers/scheduling_euler_a.py16
3 files changed, 15 insertions, 15 deletions
diff --git a/infer.py b/infer.py
index 40720ea..d917239 100644
--- a/infer.py
+++ b/infer.py
@@ -91,7 +91,7 @@ def create_cmd_parser():
91 parser.add_argument( 91 parser.add_argument(
92 "--seed", 92 "--seed",
93 type=int, 93 type=int,
94 default=torch.random.seed(), 94 default=None,
95 ) 95 )
96 parser.add_argument( 96 parser.add_argument(
97 "--config", 97 "--config",
@@ -167,11 +167,15 @@ def generate(output_dir, pipeline, args):
167 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}") 167 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt)[:100]}")
168 output_dir.mkdir(parents=True, exist_ok=True) 168 output_dir.mkdir(parents=True, exist_ok=True)
169 169
170 seed = args.seed or torch.random.seed()
171
170 save_args(output_dir, args) 172 save_args(output_dir, args)
171 173
172 with autocast("cuda"): 174 with autocast("cuda"):
173 for i in range(args.batch_num): 175 for i in range(args.batch_num):
174 generator = torch.Generator(device="cuda").manual_seed(args.seed + i) 176 pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
177
178 generator = torch.Generator(device="cuda").manual_seed(seed + i)
175 images = pipeline( 179 images = pipeline(
176 prompt=[args.prompt] * args.batch_size, 180 prompt=[args.prompt] * args.batch_size,
177 height=args.height, 181 height=args.height,
@@ -183,7 +187,7 @@ def generate(output_dir, pipeline, args):
183 ).images 187 ).images
184 188
185 for j, image in enumerate(images): 189 for j, image in enumerate(images):
186 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) 190 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
187 191
188 192
189class CmdParse(cmd.Cmd): 193class CmdParse(cmd.Cmd):
diff --git a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
index ddf7ce1..eff74b5 100644
--- a/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
+++ b/pipelines/stable_diffusion/clip_guided_stable_diffusion.py
@@ -254,10 +254,10 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
254 noise_pred = None 254 noise_pred = None
255 if isinstance(self.scheduler, EulerAScheduler): 255 if isinstance(self.scheduler, EulerAScheduler):
256 sigma = t.reshape(1) 256 sigma = t.reshape(1)
257 sigma_in = torch.cat([sigma] * 2) 257 sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
258 # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale) 258 # noise_pred = model(latent_model_input,sigma_in,uncond_embeddings, text_embeddings,guidance_scale)
259 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, 259 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
260 text_embeddings, guidance_scale, DSsigmas=self.scheduler.DSsigmas) 260 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
261 # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample 261 # noise_pred = self.unet(latent_model_input, sigma_in, encoder_hidden_states=text_embeddings).sample
262 else: 262 else:
263 # predict the noise residual 263 # predict the noise residual
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 29ebd07..9fbedaa 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -41,7 +41,6 @@ def get_sigmas(sigmas, n=None):
41 return append_zero(sigmas.flip(0)) 41 return append_zero(sigmas.flip(0))
42 t_max = len(sigmas) - 1 # = 999 42 t_max = len(sigmas) - 1 # = 999
43 t = torch.linspace(t_max, 0, n, device=sigmas.device) 43 t = torch.linspace(t_max, 0, n, device=sigmas.device)
44 # t = torch.linspace(t_max, 0, n, device=sigmas.device)
45 return append_zero(t_to_sigma(t, sigmas)) 44 return append_zero(t_to_sigma(t, sigmas))
46 45
47# from k_samplers utils.py 46# from k_samplers utils.py
@@ -55,14 +54,15 @@ def append_dims(x, target_dims):
55 return x[(...,) + (None,) * dims_to_append] 54 return x[(...,) + (None,) * dims_to_append]
56 55
57 56
58def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, DSsigmas=None): 57def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None):
59 # x_in = torch.cat([x] * 2)#A# concat the latent 58 # x_in = torch.cat([x] * 2)#A# concat the latent
60 # sigma_in = torch.cat([sigma] * 2) #A# concat sigma 59 # sigma_in = torch.cat([sigma] * 2) #A# concat sigma
61 # cond_in = torch.cat([uncond, cond]) 60 # cond_in = torch.cat([uncond, cond])
62 # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) 61 # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
63 # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) 62 # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2)
64 # return uncond + (cond - uncond) * cond_scale 63 # return uncond + (cond - uncond) * cond_scale
65 noise_pred = DiscreteEpsDDPMDenoiserForward(Unet, x_in, sigma_in, DSsigmas=DSsigmas, cond=cond_in) 64 noise_pred = DiscreteEpsDDPMDenoiserForward(
65 Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in)
66 return noise_pred 66 return noise_pred
67 67
68# from k_samplers sampling.py 68# from k_samplers sampling.py
@@ -82,9 +82,7 @@ def get_scalings(sigma):
82# DiscreteSchedule DS 82# DiscreteSchedule DS
83 83
84 84
85def DSsigma_to_t(sigma, quantize=None, DSsigmas=None): 85def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
86 # quantize = self.quantize if quantize is None else quantize
87 quantize = False
88 dists = torch.abs(sigma - DSsigmas[:, None]) 86 dists = torch.abs(sigma - DSsigmas[:, None])
89 if quantize: 87 if quantize:
90 return torch.argmin(dists, dim=0).view(sigma.shape) 88 return torch.argmin(dists, dim=0).view(sigma.shape)
@@ -96,13 +94,11 @@ def DSsigma_to_t(sigma, quantize=None, DSsigmas=None):
96 return t.view(sigma.shape) 94 return t.view(sigma.shape)
97 95
98 96
99def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, **kwargs): 97def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
100 sigma = sigma.to(Unet.device) 98 sigma = sigma.to(Unet.device)
101 DSsigmas = DSsigmas.to(Unet.device) 99 DSsigmas = DSsigmas.to(Unet.device)
102 c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] 100 c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
103 # ??? what is eps? 101 eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
104 # eps = CVDget_eps(Unet,input * c_in, DSsigma_to_t(sigma), **kwargs)
105 eps = Unet(input * c_in, DSsigma_to_t(sigma, DSsigmas=DSsigmas),
106 encoder_hidden_states=kwargs['cond']).sample 102 encoder_hidden_states=kwargs['cond']).sample
107 return input + eps * c_out 103 return input + eps * c_out
108 104