summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py51
1 files changed, 27 insertions, 24 deletions
diff --git a/training/functional.py b/training/functional.py
index ff6d3a9..4220c79 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -143,28 +143,29 @@ def save_samples(
143 for prompt in batch["nprompt_ids"] 143 for prompt in batch["nprompt_ids"]
144 ] 144 ]
145 145
146 for i in range(num_batches): 146 with torch.inference_mode():
147 start = i * batch_size 147 for i in range(num_batches):
148 end = (i + 1) * batch_size 148 start = i * batch_size
149 prompt = prompt_ids[start:end] 149 end = (i + 1) * batch_size
150 nprompt = nprompt_ids[start:end] 150 prompt = prompt_ids[start:end]
151 151 nprompt = nprompt_ids[start:end]
152 samples = pipeline( 152
153 prompt=prompt, 153 samples = pipeline(
154 negative_prompt=nprompt, 154 prompt=prompt,
155 height=image_size, 155 negative_prompt=nprompt,
156 width=image_size, 156 height=image_size,
157 generator=gen, 157 width=image_size,
158 guidance_scale=guidance_scale, 158 generator=gen,
159 sag_scale=0, 159 guidance_scale=guidance_scale,
160 num_inference_steps=num_steps, 160 sag_scale=0,
161 output_type='pil' 161 num_inference_steps=num_steps,
162 ).images 162 output_type='pil'
163 163 ).images
164 all_samples += samples 164
165 165 all_samples += samples
166 image_grid = make_grid(all_samples, grid_rows, grid_cols) 166
167 image_grid.save(file_path, quality=85) 167 image_grid = make_grid(all_samples, grid_rows, grid_cols)
168 image_grid.save(file_path, quality=85)
168 169
169 del generator 170 del generator
170 del pipeline 171 del pipeline
@@ -482,7 +483,8 @@ def train_loop(
482 local_progress_bar.clear() 483 local_progress_bar.clear()
483 global_progress_bar.clear() 484 global_progress_bar.clear()
484 485
485 on_sample(global_step + global_step_offset) 486 with on_eval():
487 on_sample(global_step + global_step_offset)
486 488
487 if epoch % checkpoint_frequency == 0 and epoch != 0: 489 if epoch % checkpoint_frequency == 0 and epoch != 0:
488 local_progress_bar.clear() 490 local_progress_bar.clear()
@@ -606,7 +608,8 @@ def train_loop(
606 # Create the pipeline using using the trained modules and save it. 608 # Create the pipeline using using the trained modules and save it.
607 if accelerator.is_main_process: 609 if accelerator.is_main_process:
608 print("Finished!") 610 print("Finished!")
609 on_sample(global_step + global_step_offset) 611 with on_eval():
612 on_sample(global_step + global_step_offset)
610 on_checkpoint(global_step + global_step_offset, "end") 613 on_checkpoint(global_step + global_step_offset, "end")
611 614
612 except KeyboardInterrupt: 615 except KeyboardInterrupt: