From cda7eba710dfde7b2e67964bcf76cd410c6a4a63 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 10 Apr 2023 13:42:50 +0200 Subject: Update --- training/functional.py | 51 ++++++++++++++++++++++------------------- training/strategy/dreambooth.py | 2 +- training/strategy/lora.py | 2 +- training/strategy/ti.py | 2 +- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/training/functional.py b/training/functional.py index ff6d3a9..4220c79 100644 --- a/training/functional.py +++ b/training/functional.py @@ -143,28 +143,29 @@ def save_samples( for prompt in batch["nprompt_ids"] ] - for i in range(num_batches): - start = i * batch_size - end = (i + 1) * batch_size - prompt = prompt_ids[start:end] - nprompt = nprompt_ids[start:end] - - samples = pipeline( - prompt=prompt, - negative_prompt=nprompt, - height=image_size, - width=image_size, - generator=gen, - guidance_scale=guidance_scale, - sag_scale=0, - num_inference_steps=num_steps, - output_type='pil' - ).images - - all_samples += samples - - image_grid = make_grid(all_samples, grid_rows, grid_cols) - image_grid.save(file_path, quality=85) + with torch.inference_mode(): + for i in range(num_batches): + start = i * batch_size + end = (i + 1) * batch_size + prompt = prompt_ids[start:end] + nprompt = nprompt_ids[start:end] + + samples = pipeline( + prompt=prompt, + negative_prompt=nprompt, + height=image_size, + width=image_size, + generator=gen, + guidance_scale=guidance_scale, + sag_scale=0, + num_inference_steps=num_steps, + output_type='pil' + ).images + + all_samples += samples + + image_grid = make_grid(all_samples, grid_rows, grid_cols) + image_grid.save(file_path, quality=85) del generator del pipeline @@ -482,7 +483,8 @@ def train_loop( local_progress_bar.clear() global_progress_bar.clear() - on_sample(global_step + global_step_offset) + with on_eval(): + on_sample(global_step + global_step_offset) if epoch % checkpoint_frequency == 0 and epoch != 0: local_progress_bar.clear() @@ -606,7 +608,8 @@ def train_loop( # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: print("Finished!") - on_sample(global_step + global_step_offset) + with on_eval(): + on_sample(global_step + global_step_offset) on_checkpoint(global_step + global_step_offset, "end") except KeyboardInterrupt: diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 7cdfc7f..fa51bc7 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -149,7 +149,7 @@ def dreambooth_strategy_callbacks( if torch.cuda.is_available(): torch.cuda.empty_cache() - @on_eval() + @torch.no_grad() def on_sample(step): unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 0f72a17..73ec8f2 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -146,7 +146,7 @@ def lora_strategy_callbacks( if torch.cuda.is_available(): torch.cuda.empty_cache() - @on_eval() + @torch.no_grad() def on_sample(step): unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index f00045f..08af89d 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -142,7 +142,7 @@ def textual_inversion_strategy_callbacks( checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" ) - @on_eval() + @torch.no_grad() def on_sample(step): unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) -- cgit v1.2.3-70-g09d2