summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/util.py106
1 files changed, 49 insertions, 57 deletions
diff --git a/training/util.py b/training/util.py
index 5c056a6..a0c15cd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -60,7 +60,7 @@ class CheckpointerBase:
60 self.sample_batches = sample_batches 60 self.sample_batches = sample_batches
61 self.sample_batch_size = sample_batch_size 61 self.sample_batch_size = sample_batch_size
62 62
63 @torch.no_grad() 63 @torch.inference_mode()
64 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): 64 def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
65 samples_path = Path(self.output_dir).joinpath("samples") 65 samples_path = Path(self.output_dir).joinpath("samples")
66 66
@@ -68,65 +68,57 @@ class CheckpointerBase:
68 val_data = self.datamodule.val_dataloader() 68 val_data = self.datamodule.val_dataloader()
69 69
70 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) 70 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
71 stable_latents = torch.randn(
72 (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
73 device=pipeline.device,
74 generator=generator,
75 )
76 71
77 grid_cols = min(self.sample_batch_size, 4) 72 grid_cols = min(self.sample_batch_size, 4)
78 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols 73 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
79 74
80 with torch.autocast("cuda"), torch.inference_mode(): 75 for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]:
81 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 76 all_samples = []
82 all_samples = [] 77 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
83 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 78 file_path.parent.mkdir(parents=True, exist_ok=True)
84 file_path.parent.mkdir(parents=True, exist_ok=True) 79
85 80 data_enum = enumerate(data)
86 data_enum = enumerate(data) 81
87 82 batches = [
88 batches = [ 83 batch
89 batch 84 for j, batch in data_enum
90 for j, batch in data_enum 85 if j * data.batch_size < self.sample_batch_size * self.sample_batches
91 if j * data.batch_size < self.sample_batch_size * self.sample_batches 86 ]
92 ] 87 prompts = [
93 prompts = [ 88 prompt
94 prompt 89 for batch in batches
95 for batch in batches 90 for prompt in batch["prompts"]
96 for prompt in batch["prompts"] 91 ]
97 ] 92 nprompts = [
98 nprompts = [ 93 prompt
99 prompt 94 for batch in batches
100 for batch in batches 95 for prompt in batch["nprompts"]
101 for prompt in batch["nprompts"] 96 ]
102 ] 97
103 98 for i in range(self.sample_batches):
104 for i in range(self.sample_batches): 99 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
105 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] 100 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
106 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] 101
107 102 samples = pipeline(
108 samples = pipeline( 103 prompt=prompt,
109 prompt=prompt, 104 negative_prompt=nprompt,
110 negative_prompt=nprompt, 105 height=self.sample_image_size,
111 height=self.sample_image_size, 106 width=self.sample_image_size,
112 width=self.sample_image_size, 107 generator=gen,
113 image=latents[:len(prompt)] if latents is not None else None, 108 guidance_scale=guidance_scale,
114 generator=generator if latents is not None else None, 109 eta=eta,
115 guidance_scale=guidance_scale, 110 num_inference_steps=num_inference_steps,
116 eta=eta, 111 output_type='pil'
117 num_inference_steps=num_inference_steps, 112 ).images
118 output_type='pil' 113
119 ).images 114 all_samples += samples
120 115
121 all_samples += samples 116 del samples
122 117
123 del samples 118 image_grid = make_grid(all_samples, grid_rows, grid_cols)
124 119 image_grid.save(file_path, quality=85)
125 image_grid = make_grid(all_samples, grid_rows, grid_cols) 120
126 image_grid.save(file_path, quality=85) 121 del all_samples
127 122 del image_grid
128 del all_samples
129 del image_grid
130 123
131 del generator 124 del generator
132 del stable_latents