From cda7eba710dfde7b2e67964bcf76cd410c6a4a63 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 10 Apr 2023 13:42:50 +0200 Subject: Update --- training/functional.py | 51 ++++++++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 24 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index ff6d3a9..4220c79 100644 --- a/training/functional.py +++ b/training/functional.py @@ -143,28 +143,29 @@ def save_samples( for prompt in batch["nprompt_ids"] ] - for i in range(num_batches): - start = i * batch_size - end = (i + 1) * batch_size - prompt = prompt_ids[start:end] - nprompt = nprompt_ids[start:end] - - samples = pipeline( - prompt=prompt, - negative_prompt=nprompt, - height=image_size, - width=image_size, - generator=gen, - guidance_scale=guidance_scale, - sag_scale=0, - num_inference_steps=num_steps, - output_type='pil' - ).images - - all_samples += samples - - image_grid = make_grid(all_samples, grid_rows, grid_cols) - image_grid.save(file_path, quality=85) + with torch.inference_mode(): + for i in range(num_batches): + start = i * batch_size + end = (i + 1) * batch_size + prompt = prompt_ids[start:end] + nprompt = nprompt_ids[start:end] + + samples = pipeline( + prompt=prompt, + negative_prompt=nprompt, + height=image_size, + width=image_size, + generator=gen, + guidance_scale=guidance_scale, + sag_scale=0, + num_inference_steps=num_steps, + output_type='pil' + ).images + + all_samples += samples + + image_grid = make_grid(all_samples, grid_rows, grid_cols) + image_grid.save(file_path, quality=85) del generator del pipeline @@ -482,7 +483,8 @@ def train_loop( local_progress_bar.clear() global_progress_bar.clear() - on_sample(global_step + global_step_offset) + with on_eval(): + on_sample(global_step + global_step_offset) if epoch % checkpoint_frequency == 0 and epoch != 0: local_progress_bar.clear() @@ -606,7 +608,8 @@ def train_loop( # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished!") - on_sample(global_step + global_step_offset) + with on_eval(): + on_sample(global_step + global_step_offset) on_checkpoint(global_step + global_step_offset, "end") except KeyboardInterrupt: -- cgit v1.2.3-54-g00ecf