-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py | 95
-rw-r--r--  train_dreambooth.py                                  |  1
-rw-r--r--  train_ti.py                                          |  1
-rw-r--r--  training/functional.py                               | 12
4 files changed, 41 insertions, 68 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index cb09fe1..c4f7401 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -293,53 +293,39 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt_embeds
 
-    def get_timesteps(self, latents_are_image, num_inference_steps, strength, device):
-        if latents_are_image:
-            # get the original timestep using init_timestep
-            offset = self.scheduler.config.get("steps_offset", 0)
-            init_timestep = int(num_inference_steps * strength) + offset
-            init_timestep = min(init_timestep, num_inference_steps)
-
-            t_start = max(num_inference_steps - init_timestep + offset, 0)
-            timesteps = self.scheduler.timesteps[t_start:]
-        else:
-            timesteps = self.scheduler.timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:]
 
         timesteps = timesteps.to(device)
 
         return timesteps
 
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
-
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
+    def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
+        return torch.randn(
+            (batch_size, 1, 1, 1),
+            dtype=dtype,
+            device=device,
+            generator=generator
+        ).expand(batch_size, 3, width, height)
 
-        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        else:
-            latents = latents.to(device=device, dtype=dtype)
-
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
-
-        return latents
-
-    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
+    def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
+        init_latents = self.vae.config.scaling_factor * init_latents
 
-        if batch_size > init_latents.shape[0]:
+        if batch_size % init_latents.shape[0] != 0:
             raise ValueError(
                 f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
             )
         else:
-            init_latents = torch.cat([init_latents] * batch_size, dim=0)
+            batch_multiplier = batch_size // init_latents.shape[0]
+            init_latents = torch.cat([init_latents] * batch_multiplier, dim=0)
 
         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
@@ -368,7 +354,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return extra_step_kwargs
 
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
@@ -381,7 +367,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
         num_images_per_prompt: int = 1,
-        strength: float = 0.8,
+        strength: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -461,7 +447,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
         device = self.execution_device
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
-        latents_are_image = isinstance(image, PIL.Image.Image)
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -474,33 +459,31 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps = self.get_timesteps(latents_are_image, num_inference_steps, strength, device)
+        timesteps = self.get_timesteps(num_inference_steps, strength, device)
 
         # 5. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
-        if latents_are_image:
+        if isinstance(image, PIL.Image.Image):
             image = preprocess(image)
-            latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-            latents = self.prepare_latents_from_image(
-                image,
-                latent_timestep,
+        elif image is None:
+            image = self.prepare_image(
                 batch_size * num_images_per_prompt,
-                prompt_embeds.dtype,
-                device,
-                generator
-            )
-        else:
-            latents = self.prepare_latents(
-                batch_size * num_images_per_prompt,
-                num_channels_latents,
-                height,
                 width,
+                height,
                 prompt_embeds.dtype,
                 device,
-                generator,
-                image,
+                generator
             )
 
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        latents = self.prepare_latents(
+            image,
+            latent_timestep,
+            batch_size * num_images_per_prompt,
+            prompt_embeds.dtype,
+            device,
+            generator
+        )
+
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
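
With the pipeline change above, every generation goes through the same img2img-style path: a supplied PIL image is preprocessed and VAE-encoded in prepare_latents(), and a missing image is replaced by pure noise from prepare_image() with the default strength raised to 1.0. A minimal usage sketch follows; the checkpoint directory, prompt, and image file names are placeholders, and loading the custom pipeline through the standard diffusers from_pretrained mechanism is an assumption, not part of this commit.

# Minimal usage sketch (assumed loading path and file names; only `image`,
# `strength`, and the always-img2img behaviour come from the diff above).
import torch
from PIL import Image
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

pipe = VlpnStableDiffusion.from_pretrained(
    "path/to/saved/pipeline",  # placeholder checkpoint directory
    torch_dtype=torch.float16,
).to("cuda")

# Text-to-image: no init image, so prepare_image() supplies random noise and
# the new default strength=1.0 lets the scheduler replace it entirely.
result = pipe("a watercolor fox", num_inference_steps=50)

# Image-to-image: a PIL image is VAE-encoded in prepare_latents() and only
# partially re-noised, because strength < 1.0 keeps part of the input.
init = Image.open("sketch.png").convert("RGB").resize((512, 512))
result = pipe("a watercolor fox", image=init, strength=0.6)
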
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8571dff..9b91172 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -526,7 +526,6 @@ def main():
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
         no_val=args.valid_set_size == 0,
-        # noise_offset=0,
     )
 
     checkpoint_output_dir = output_dir / "model"
diff --git a/train_ti.py b/train_ti.py
index bc9348d..c139cc0 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -607,7 +607,6 @@ def main():
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
         no_val=args.valid_set_size == 0,
-        noise_offset=0,
         strategy=textual_inversion_strategy,
         num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
diff --git a/training/functional.py b/training/functional.py
index 36269f0..1c38635 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -253,7 +253,6 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
-    noise_offset: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -268,17 +267,12 @@ def loss_step(
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
 
     # Sample noise that we'll add to the latents
-    offsets = noise_offset * torch.randn(
-        latents.shape[0], 1, 1, 1,
+    noise = torch.randn(
+        latents.shape,
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
         generator=generator
-    ).expand(latents.shape)
-    noise = torch.normal(
-        mean=offsets,
-        std=1,
-        generator=generator,
     )
 
     # Sample a random timestep for each image
@@ -565,7 +559,6 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
-    noise_offset: float = 0.2,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -600,7 +593,6 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
-        noise_offset,
         seed,
     )
 
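
For reference, the training/functional.py hunks remove the "offset noise" trick: loss_step previously drew noise around a per-sample random mean scaled by noise_offset (which defaulted to 0.2 in train()), and now samples plain standard Gaussian noise. A standalone sketch of the two variants; the tensor shape and seed are illustrative only, not taken from the repo.

# Standalone comparison of the removed offset-noise sampling and the new
# plain Gaussian sampling used by the updated loss_step.
import torch

latents = torch.randn(4, 4, 64, 64)
generator = torch.Generator().manual_seed(0)

# Removed variant: a per-sample offset shifts the mean of the noise.
noise_offset = 0.2  # former default in train()
offsets = noise_offset * torch.randn(
    latents.shape[0], 1, 1, 1, generator=generator
).expand(latents.shape)
offset_noise = torch.normal(mean=offsets, std=1, generator=generator)

# New variant: zero-mean, unit-variance noise, matching the simplified code.
plain_noise = torch.randn(latents.shape, generator=generator)
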