diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/util.py | 106 |
1 files changed, 49 insertions, 57 deletions
diff --git a/training/util.py b/training/util.py index 5c056a6..a0c15cd 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -60,7 +60,7 @@ class CheckpointerBase: | |||
60 | self.sample_batches = sample_batches | 60 | self.sample_batches = sample_batches |
61 | self.sample_batch_size = sample_batch_size | 61 | self.sample_batch_size = sample_batch_size |
62 | 62 | ||
63 | @torch.no_grad() | 63 | @torch.inference_mode() |
64 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 64 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
65 | samples_path = Path(self.output_dir).joinpath("samples") | 65 | samples_path = Path(self.output_dir).joinpath("samples") |
66 | 66 | ||
@@ -68,65 +68,57 @@ class CheckpointerBase: | |||
68 | val_data = self.datamodule.val_dataloader() | 68 | val_data = self.datamodule.val_dataloader() |
69 | 69 | ||
70 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 70 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
71 | stable_latents = torch.randn( | ||
72 | (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), | ||
73 | device=pipeline.device, | ||
74 | generator=generator, | ||
75 | ) | ||
76 | 71 | ||
77 | grid_cols = min(self.sample_batch_size, 4) | 72 | grid_cols = min(self.sample_batch_size, 4) |
78 | grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols | 73 | grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols |
79 | 74 | ||
80 | with torch.autocast("cuda"), torch.inference_mode(): | 75 | for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]: |
81 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 76 | all_samples = [] |
82 | all_samples = [] | 77 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
83 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 78 | file_path.parent.mkdir(parents=True, exist_ok=True) |
84 | file_path.parent.mkdir(parents=True, exist_ok=True) | 79 | |
85 | 80 | data_enum = enumerate(data) | |
86 | data_enum = enumerate(data) | 81 | |
87 | 82 | batches = [ | |
88 | batches = [ | 83 | batch |
89 | batch | 84 | for j, batch in data_enum |
90 | for j, batch in data_enum | 85 | if j * data.batch_size < self.sample_batch_size * self.sample_batches |
91 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | 86 | ] |
92 | ] | 87 | prompts = [ |
93 | prompts = [ | 88 | prompt |
94 | prompt | 89 | for batch in batches |
95 | for batch in batches | 90 | for prompt in batch["prompts"] |
96 | for prompt in batch["prompts"] | 91 | ] |
97 | ] | 92 | nprompts = [ |
98 | nprompts = [ | 93 | prompt |
99 | prompt | 94 | for batch in batches |
100 | for batch in batches | 95 | for prompt in batch["nprompts"] |
101 | for prompt in batch["nprompts"] | 96 | ] |
102 | ] | 97 | |
103 | 98 | for i in range(self.sample_batches): | |
104 | for i in range(self.sample_batches): | 99 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] |
105 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | 100 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] |
106 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | 101 | |
107 | 102 | samples = pipeline( | |
108 | samples = pipeline( | 103 | prompt=prompt, |
109 | prompt=prompt, | 104 | negative_prompt=nprompt, |
110 | negative_prompt=nprompt, | 105 | height=self.sample_image_size, |
111 | height=self.sample_image_size, | 106 | width=self.sample_image_size, |
112 | width=self.sample_image_size, | 107 | generator=gen, |
113 | image=latents[:len(prompt)] if latents is not None else None, | 108 | guidance_scale=guidance_scale, |
114 | generator=generator if latents is not None else None, | 109 | eta=eta, |
115 | guidance_scale=guidance_scale, | 110 | num_inference_steps=num_inference_steps, |
116 | eta=eta, | 111 | output_type='pil' |
117 | num_inference_steps=num_inference_steps, | 112 | ).images |
118 | output_type='pil' | 113 | |
119 | ).images | 114 | all_samples += samples |
120 | 115 | ||
121 | all_samples += samples | 116 | del samples |
122 | 117 | ||
123 | del samples | 118 | image_grid = make_grid(all_samples, grid_rows, grid_cols) |
124 | 119 | image_grid.save(file_path, quality=85) | |
125 | image_grid = make_grid(all_samples, grid_rows, grid_cols) | 120 | |
126 | image_grid.save(file_path, quality=85) | 121 | del all_samples |
127 | 122 | del image_grid | |
128 | del all_samples | ||
129 | del image_grid | ||
130 | 123 | ||
131 | del generator | 124 | del generator |
132 | del stable_latents | ||