-rw-r--r--  infer.py             |  48
-rw-r--r--  train_dreambooth.py  |  12
-rw-r--r--  train_ti.py          |   8
-rw-r--r--  training/util.py     | 106
4 files changed, 92 insertions(+), 82 deletions(-)
diff --git a/infer.py b/infer.py
index 420cb83..f566114 100644
--- a/infer.py
+++ b/infer.py
@@ -209,6 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     return pipeline
 
 
+@torch.inference_mode()
 def generate(output_dir, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
@@ -245,30 +246,29 @@ def generate(output_dir, pipeline, args):
     elif args.scheduler == "kdpm2_a":
         pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
-    with torch.autocast("cuda"), torch.inference_mode():
-        for i in range(args.batch_num):
-            pipeline.set_progress_bar_config(
-                desc=f"Batch {i + 1} of {args.batch_num}",
-                dynamic_ncols=True
-            )
-
-            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
-            images = pipeline(
-                prompt=args.prompt,
-                negative_prompt=args.negative_prompt,
-                height=args.height,
-                width=args.width,
-                num_images_per_prompt=args.batch_size,
-                num_inference_steps=args.steps,
-                guidance_scale=args.guidance_scale,
-                generator=generator,
-                image=init_image,
-                strength=args.image_noise,
-            ).images
-
-            for j, image in enumerate(images):
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
+    for i in range(args.batch_num):
+        pipeline.set_progress_bar_config(
+            desc=f"Batch {i + 1} of {args.batch_num}",
+            dynamic_ncols=True
+        )
+
+        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
+        images = pipeline(
+            prompt=args.prompt,
+            negative_prompt=args.negative_prompt,
+            height=args.height,
+            width=args.width,
+            num_images_per_prompt=args.batch_size,
+            num_inference_steps=args.steps,
+            guidance_scale=args.guidance_scale,
+            generator=generator,
+            image=init_image,
+            strength=args.image_noise,
+        ).images
+
+        for j, image in enumerate(images):
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
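
Note on the infer.py change above: torch.inference_mode() works both as a context manager and as a decorator, so decorating generate() covers the whole loop that previously sat under the explicit with-block; dropping torch.autocast presumably leaves the pipeline running in whatever dtype its weights were loaded in. A minimal sketch of the decorator pattern follows (the model and shapes are placeholders, not the repo's pipeline):

    # Sketch only: placeholder model, not the repo's VlpnStableDiffusion pipeline.
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)

    @torch.inference_mode()
    def generate(batch_num: int):
        # Everything inside runs without autograd tracking, same as a with-block
        # around the loop but applied to every call of the function.
        outputs = []
        for i in range(batch_num):
            generator = torch.Generator().manual_seed(1000 + i)
            x = torch.randn(2, 4, generator=generator)
            outputs.append(model(x))
        return outputs

    print(len(generate(3)))  # 3 batches, no autograd graphs were built
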
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e239833..2c765ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -389,6 +389,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
+        weight_dtype,
         datamodule,
         accelerator,
         vae,
@@ -416,6 +417,7 @@ class Checkpointer(CheckpointerBase):
             sample_batch_size=sample_batch_size
         )
 
+        self.weight_dtype = weight_dtype
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -452,6 +454,12 @@ class Checkpointer(CheckpointerBase):
         unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
+        orig_unet_dtype = unet.dtype
+        orig_text_encoder_dtype = text_encoder.dtype
+
+        unet.to(dtype=self.weight_dtype)
+        text_encoder.to(dtype=self.weight_dtype)
+
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=self.vae,
@@ -463,6 +471,9 @@ class Checkpointer(CheckpointerBase):
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
+        unet.to(dtype=orig_unet_dtype)
+        text_encoder.to(dtype=orig_text_encoder_dtype)
+
         del unet
         del text_encoder
         del pipeline
@@ -798,6 +809,7 @@ def main():
     max_acc_val = 0.0
 
     checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
         datamodule=datamodule,
         accelerator=accelerator,
         vae=vae,
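
The train_dreambooth.py change above casts the live training modules to weight_dtype only for sampling and restores their original dtype afterwards, so training resumes in the precision it was running in. A rough sketch of that round-trip (generic module and hypothetical helper, not the actual Checkpointer; plain nn.Module lacks the .dtype property that diffusers/transformers models provide, hence the parameters() lookup):

    # Sketch only: stand-in module, not the UNet or text encoder.
    import torch
    import torch.nn as nn

    def sample_with_dtype(module: nn.Module, weight_dtype: torch.dtype):
        orig_dtype = next(module.parameters()).dtype
        module.to(dtype=weight_dtype)          # cast down for cheap sampling
        try:
            with torch.inference_mode():
                out = module(torch.ones(1, 8, dtype=weight_dtype))
        finally:
            module.to(dtype=orig_dtype)        # restore precision for training
        return out

    print(sample_with_dtype(nn.Linear(8, 2), torch.bfloat16))
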
diff --git a/train_ti.py b/train_ti.py
index 5f37d54..a228795 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -361,6 +361,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
+        weight_dtype,
         datamodule,
         accelerator,
         vae,
@@ -387,6 +388,7 @@ class Checkpointer(CheckpointerBase):
             sample_batch_size=sample_batch_size
         )
 
+        self.weight_dtype = weight_dtype
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -417,8 +419,9 @@ class Checkpointer(CheckpointerBase):
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+        orig_dtype = text_encoder.dtype
+        text_encoder.to(dtype=self.weight_dtype)
 
-        # Save a sample image
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=self.vae,
@@ -430,6 +433,8 @@ class Checkpointer(CheckpointerBase):
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
+        text_encoder.to(dtype=orig_dtype)
+
         del text_encoder
         del pipeline
 
@@ -739,6 +744,7 @@ def main():
     max_acc_val = 0.0
 
     checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
         datamodule=datamodule,
         accelerator=accelerator,
         vae=vae,
diff --git a/training/util.py b/training/util.py
index 5c056a6..a0c15cd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -60,7 +60,7 @@ class CheckpointerBase:
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
@@ -68,65 +68,57 @@ class CheckpointerBase:
         val_data = self.datamodule.val_dataloader()
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
 
         grid_cols = min(self.sample_batch_size, 4)
         grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
 
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, grid_rows, grid_cols)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]:
+            all_samples = []
+            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
+
+            data_enum = enumerate(data)
+
+            batches = [
+                batch
+                for j, batch in data_enum
+                if j * data.batch_size < self.sample_batch_size * self.sample_batches
+            ]
+            prompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["prompts"]
+            ]
+            nprompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ]
+
+            for i in range(self.sample_batches):
+                prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+
+                samples = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    generator=gen,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                ).images
+
+                all_samples += samples
+
+                del samples
+
+            image_grid = make_grid(all_samples, grid_rows, grid_cols)
+            image_grid.save(file_path, quality=85)
+
+            del all_samples
+            del image_grid
 
         del generator
-        del stable_latents
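
Aside on the no_grad to inference_mode swap in save_samples above: inference_mode is the stricter of the two, since tensors created under it are inference tensors and cannot later be used in operations recorded by autograd. That is fine for throwaway sample images, but it is the behavioural difference worth knowing about. Toy illustration, unrelated to the training scripts themselves:

    import torch

    with torch.no_grad():
        a = torch.ones(3)
    with torch.inference_mode():
        b = torch.ones(3)

    w = torch.ones(3, requires_grad=True)
    (w * a).sum().backward()           # fine: 'a' is an ordinary tensor without grad
    try:
        (w * b).sum().backward()       # fails: 'b' is an inference tensor
    except RuntimeError as e:
        print("inference tensor in autograd:", e)
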