 infer.py            | 44
 train_dreambooth.py | 12
 train_ti.py         |  8
 training/util.py    | 90
 4 files changed, 82 insertions(+), 72 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -209,6 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     return pipeline
 
 
+@torch.inference_mode()
 def generate(output_dir, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
@@ -245,30 +246,29 @@ def generate(output_dir, pipeline, args):
     elif args.scheduler == "kdpm2_a":
         pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
-    with torch.autocast("cuda"), torch.inference_mode():
-        for i in range(args.batch_num):
-            pipeline.set_progress_bar_config(
-                desc=f"Batch {i + 1} of {args.batch_num}",
-                dynamic_ncols=True
-            )
+    for i in range(args.batch_num):
+        pipeline.set_progress_bar_config(
+            desc=f"Batch {i + 1} of {args.batch_num}",
+            dynamic_ncols=True
+        )
 
-            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
-            images = pipeline(
-                prompt=args.prompt,
-                negative_prompt=args.negative_prompt,
-                height=args.height,
-                width=args.width,
-                num_images_per_prompt=args.batch_size,
-                num_inference_steps=args.steps,
-                guidance_scale=args.guidance_scale,
-                generator=generator,
-                image=init_image,
-                strength=args.image_noise,
-            ).images
+        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
+        images = pipeline(
+            prompt=args.prompt,
+            negative_prompt=args.negative_prompt,
+            height=args.height,
+            width=args.width,
+            num_images_per_prompt=args.batch_size,
+            num_inference_steps=args.steps,
+            guidance_scale=args.guidance_scale,
+            generator=generator,
+            image=init_image,
+            strength=args.image_noise,
+        ).images
 
-            for j, image in enumerate(images):
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
+        for j, image in enumerate(images):
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
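
Note on the infer.py change: all of generate() now runs under a torch.inference_mode() decorator, and the torch.autocast("cuda") / torch.inference_mode() context around the batch loop is gone, so the pipeline simply runs in whatever dtype its weights were loaded with. A minimal sketch of the decorator pattern (the pipeline call below is a generic stand-in, not the project's VlpnStableDiffusion API):

    import torch

    @torch.inference_mode()  # no autograd tracking anywhere inside generate()
    def generate(pipeline, prompt, seed=0, steps=30):
        # A seeded generator keeps each batch reproducible across runs.
        generator = torch.Generator(device="cuda").manual_seed(seed)
        return pipeline(
            prompt=prompt,
            num_inference_steps=steps,
            generator=generator,
        ).images
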
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e239833..2c765ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -389,6 +389,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
+        weight_dtype,
         datamodule,
         accelerator,
         vae,
@@ -416,6 +417,7 @@ class Checkpointer(CheckpointerBase):
             sample_batch_size=sample_batch_size
         )
 
+        self.weight_dtype = weight_dtype
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -452,6 +454,12 @@ class Checkpointer(CheckpointerBase):
         unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
+        orig_unet_dtype = unet.dtype
+        orig_text_encoder_dtype = text_encoder.dtype
+
+        unet.to(dtype=self.weight_dtype)
+        text_encoder.to(dtype=self.weight_dtype)
+
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=self.vae,
@@ -463,6 +471,9 @@ class Checkpointer(CheckpointerBase):
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
+        unet.to(dtype=orig_unet_dtype)
+        text_encoder.to(dtype=orig_text_encoder_dtype)
+
         del unet
         del text_encoder
         del pipeline
@@ -798,6 +809,7 @@ def main():
     max_acc_val = 0.0
 
     checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
         datamodule=datamodule,
         accelerator=accelerator,
         vae=vae,
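
The train_dreambooth.py hunks plumb a weight_dtype argument into the Checkpointer and wrap sampling in a cast-and-restore: record each module's dtype, cast unet and text_encoder to weight_dtype for the sample pipeline, then cast them back so training resumes in its original precision. How weight_dtype is derived is not shown in this diff; the sketch below assumes the usual accelerate-based pattern seen in diffusers training scripts:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="fp16")

    # Map the accelerator's mixed-precision setting to a torch dtype (assumed,
    # not shown in the patch); frozen models and sampling run in this dtype,
    # while trained weights stay in float32.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
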
diff --git a/train_ti.py b/train_ti.py
index 5f37d54..a228795 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -361,6 +361,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
+        weight_dtype,
         datamodule,
         accelerator,
         vae,
@@ -387,6 +388,7 @@ class Checkpointer(CheckpointerBase):
             sample_batch_size=sample_batch_size
         )
 
+        self.weight_dtype = weight_dtype
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -417,8 +419,9 @@ class Checkpointer(CheckpointerBase):
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+        orig_dtype = text_encoder.dtype
+        text_encoder.to(dtype=self.weight_dtype)
 
-        # Save a sample image
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=self.vae,
@@ -430,6 +433,8 @@ class Checkpointer(CheckpointerBase):
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
+        text_encoder.to(dtype=orig_dtype)
+
         del text_encoder
         del pipeline
 
@@ -739,6 +744,7 @@ def main():
     max_acc_val = 0.0
 
     checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
         datamodule=datamodule,
         accelerator=accelerator,
         vae=vae,
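
train_ti.py repeats the same cast-and-restore around its text encoder. The patch writes it out inline in both scripts; one way to factor it out (hypothetical, not part of this change) would be a small context manager:

    from contextlib import contextmanager

    @contextmanager
    def cast_for_sampling(module, dtype):
        # Temporarily move a module to `dtype`, restoring its original dtype on exit.
        orig_dtype = module.dtype
        module.to(dtype=dtype)
        try:
            yield module
        finally:
            module.to(dtype=orig_dtype)

    # Hypothetical usage inside save_samples():
    #   with cast_for_sampling(text_encoder, self.weight_dtype):
    #       super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
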
diff --git a/training/util.py b/training/util.py
index 5c056a6..a0c15cd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -60,7 +60,7 @@ class CheckpointerBase:
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
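
Switching the decorator from torch.no_grad() to torch.inference_mode() does more than skip gradient computation: tensors created inside inference mode also skip autograd's version counter and view tracking, which is slightly faster but means they cannot later be re-enlisted for autograd. That is fine here, since sample images never feed back into training. A quick illustration of the difference:

    import torch

    w = torch.ones(3, requires_grad=True)

    with torch.no_grad():
        a = w * 2               # a is a normal tensor with requires_grad=False

    with torch.inference_mode():
        b = w * 2               # b is an inference tensor

    print(b.is_inference())     # True
    # a.requires_grad_(True) is allowed; b.requires_grad_(True) raises a
    # RuntimeError -- clone b outside inference mode if autograd needs it.
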
@@ -68,65 +68,57 @@ class CheckpointerBase:
         val_data = self.datamodule.val_dataloader()
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
 
         grid_cols = min(self.sample_batch_size, 4)
         grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
 
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
+        for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]:
+            all_samples = []
+            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
 
-                data_enum = enumerate(data)
+            data_enum = enumerate(data)
 
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
+            batches = [
+                batch
+                for j, batch in data_enum
+                if j * data.batch_size < self.sample_batch_size * self.sample_batches
+            ]
+            prompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["prompts"]
+            ]
+            nprompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ]
 
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+            for i in range(self.sample_batches):
+                prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
+                samples = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    generator=gen,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                ).images
 
-                    all_samples += samples
+                all_samples += samples
 
-                    del samples
+                del samples
 
-                image_grid = make_grid(all_samples, grid_rows, grid_cols)
-                image_grid.save(file_path, quality=85)
+            image_grid = make_grid(all_samples, grid_rows, grid_cols)
+            image_grid.save(file_path, quality=85)
 
-                del all_samples
-                del image_grid
+            del all_samples
+            del image_grid
 
         del generator
-        del stable_latents
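
In training/util.py the pre-generated stable_latents are gone; the "stable" pool now gets its determinism from the seeded generator alone, while the "val" and "train" pools sample with generator=None. The grid layout is unchanged: grid_cols = min(sample_batch_size, 4) and grid_rows = (sample_batches * sample_batch_size) // grid_cols, so e.g. 4 batches of 2 images give a 4x2 grid. A plausible PIL-based make_grid consistent with how it is called here (the repo's own helper may differ):

    from PIL import Image

    def make_grid(images, rows, cols):
        # Paste rows * cols equally sized PIL images into a single grid image.
        w, h = images[0].size
        grid = Image.new("RGB", (cols * w, rows * h))
        for i, img in enumerate(images):
            grid.paste(img, ((i % cols) * w, (i // cols) * h))
        return grid
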
