diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/util.py | 90 |
1 files changed, 41 insertions, 49 deletions
diff --git a/training/util.py b/training/util.py index 5c056a6..a0c15cd 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -60,7 +60,7 @@ class CheckpointerBase: | |||
| 60 | self.sample_batches = sample_batches | 60 | self.sample_batches = sample_batches |
| 61 | self.sample_batch_size = sample_batch_size | 61 | self.sample_batch_size = sample_batch_size |
| 62 | 62 | ||
| 63 | @torch.no_grad() | 63 | @torch.inference_mode() |
| 64 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 64 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
| 65 | samples_path = Path(self.output_dir).joinpath("samples") | 65 | samples_path = Path(self.output_dir).joinpath("samples") |
| 66 | 66 | ||
| @@ -68,65 +68,57 @@ class CheckpointerBase: | |||
| 68 | val_data = self.datamodule.val_dataloader() | 68 | val_data = self.datamodule.val_dataloader() |
| 69 | 69 | ||
| 70 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 70 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
| 71 | stable_latents = torch.randn( | ||
| 72 | (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), | ||
| 73 | device=pipeline.device, | ||
| 74 | generator=generator, | ||
| 75 | ) | ||
| 76 | 71 | ||
| 77 | grid_cols = min(self.sample_batch_size, 4) | 72 | grid_cols = min(self.sample_batch_size, 4) |
| 78 | grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols | 73 | grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols |
| 79 | 74 | ||
| 80 | with torch.autocast("cuda"), torch.inference_mode(): | 75 | for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]: |
| 81 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 76 | all_samples = [] |
| 82 | all_samples = [] | 77 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
| 83 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 78 | file_path.parent.mkdir(parents=True, exist_ok=True) |
| 84 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 85 | 79 | ||
| 86 | data_enum = enumerate(data) | 80 | data_enum = enumerate(data) |
| 87 | 81 | ||
| 88 | batches = [ | 82 | batches = [ |
| 89 | batch | 83 | batch |
| 90 | for j, batch in data_enum | 84 | for j, batch in data_enum |
| 91 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | 85 | if j * data.batch_size < self.sample_batch_size * self.sample_batches |
| 92 | ] | 86 | ] |
| 93 | prompts = [ | 87 | prompts = [ |
| 94 | prompt | 88 | prompt |
| 95 | for batch in batches | 89 | for batch in batches |
| 96 | for prompt in batch["prompts"] | 90 | for prompt in batch["prompts"] |
| 97 | ] | 91 | ] |
| 98 | nprompts = [ | 92 | nprompts = [ |
| 99 | prompt | 93 | prompt |
| 100 | for batch in batches | 94 | for batch in batches |
| 101 | for prompt in batch["nprompts"] | 95 | for prompt in batch["nprompts"] |
| 102 | ] | 96 | ] |
| 103 | 97 | ||
| 104 | for i in range(self.sample_batches): | 98 | for i in range(self.sample_batches): |
| 105 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | 99 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] |
| 106 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | 100 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] |
| 107 | 101 | ||
| 108 | samples = pipeline( | 102 | samples = pipeline( |
| 109 | prompt=prompt, | 103 | prompt=prompt, |
| 110 | negative_prompt=nprompt, | 104 | negative_prompt=nprompt, |
| 111 | height=self.sample_image_size, | 105 | height=self.sample_image_size, |
| 112 | width=self.sample_image_size, | 106 | width=self.sample_image_size, |
| 113 | image=latents[:len(prompt)] if latents is not None else None, | 107 | generator=gen, |
| 114 | generator=generator if latents is not None else None, | 108 | guidance_scale=guidance_scale, |
| 115 | guidance_scale=guidance_scale, | 109 | eta=eta, |
| 116 | eta=eta, | 110 | num_inference_steps=num_inference_steps, |
| 117 | num_inference_steps=num_inference_steps, | 111 | output_type='pil' |
| 118 | output_type='pil' | 112 | ).images |
| 119 | ).images | ||
| 120 | 113 | ||
| 121 | all_samples += samples | 114 | all_samples += samples |
| 122 | 115 | ||
| 123 | del samples | 116 | del samples |
| 124 | 117 | ||
| 125 | image_grid = make_grid(all_samples, grid_rows, grid_cols) | 118 | image_grid = make_grid(all_samples, grid_rows, grid_cols) |
| 126 | image_grid.save(file_path, quality=85) | 119 | image_grid.save(file_path, quality=85) |
| 127 | 120 | ||
| 128 | del all_samples | 121 | del all_samples |
| 129 | del image_grid | 122 | del image_grid |
| 130 | 123 | ||
| 131 | del generator | 124 | del generator |
| 132 | del stable_latents | ||
