From cda7eba710dfde7b2e67964bcf76cd410c6a4a63 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 10 Apr 2023 13:42:50 +0200
Subject: Run sampling under inference_mode; move on_eval() to call sites

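Sample generation never needs autograd, so run the image generation loop
in save_samples() under torch.inference_mode() and decorate the
strategies' on_sample callbacks with @torch.no_grad() instead of
@on_eval(). The eval-mode switch still happens, but it moves to the call
sites in train_loop, which now enter the on_eval() context around each
on_sample call.

torch.inference_mode() is the stricter of the two gradient guards:
besides skipping graph recording it also disables view tracking and
version-counter bumps, and tensors created inside it can never re-enter
autograd. A minimal sketch of the difference, in plain PyTorch and
independent of this repo:

    import torch

    x = torch.ones(3, requires_grad=True)

    with torch.no_grad():
        y = x * 2        # no graph recorded; y is an ordinary tensor
    y.requires_grad_(True)   # fine: y may re-enter autograd later

    with torch.inference_mode():
        z = x * 2        # no graph, no view/version-counter tracking
    z.requires_grad_(True)   # RuntimeError: z is an inference tensor and
                             # is permanently excluded from autograd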
---
 training/functional.py          | 51 ++++++++++++++++++++++-------------------
 training/strategy/dreambooth.py |  2 +-
 training/strategy/lora.py       |  2 +-
 training/strategy/ti.py         |  2 +-
 4 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index ff6d3a9..4220c79 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -143,28 +143,29 @@ def save_samples(
             for prompt in batch["nprompt_ids"]
         ]
 
-        for i in range(num_batches):
-            start = i * batch_size
-            end = (i + 1) * batch_size
-            prompt = prompt_ids[start:end]
-            nprompt = nprompt_ids[start:end]
-
-            samples = pipeline(
-                prompt=prompt,
-                negative_prompt=nprompt,
-                height=image_size,
-                width=image_size,
-                generator=gen,
-                guidance_scale=guidance_scale,
-                sag_scale=0,
-                num_inference_steps=num_steps,
-                output_type='pil'
-            ).images
-
-            all_samples += samples
-
-        image_grid = make_grid(all_samples, grid_rows, grid_cols)
-        image_grid.save(file_path, quality=85)
+        with torch.inference_mode():
+            for i in range(num_batches):
+                start = i * batch_size
+                end = (i + 1) * batch_size
+                prompt = prompt_ids[start:end]
+                nprompt = nprompt_ids[start:end]
+
+                samples = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=image_size,
+                    width=image_size,
+                    generator=gen,
+                    guidance_scale=guidance_scale,
+                    sag_scale=0,
+                    num_inference_steps=num_steps,
+                    output_type='pil'
+                ).images
+
+                all_samples += samples
+
+            image_grid = make_grid(all_samples, grid_rows, grid_cols)
+            image_grid.save(file_path, quality=85)
 
     del generator
     del pipeline
@@ -482,7 +483,8 @@ def train_loop(
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
-                    on_sample(global_step + global_step_offset)
+                    with on_eval():
+                        on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
@@ -606,7 +608,8 @@ def train_loop(
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            on_sample(global_step + global_step_offset)
+            with on_eval():
+                on_sample(global_step + global_step_offset)
             on_checkpoint(global_step + global_step_offset, "end")
 
     except KeyboardInterrupt:
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 7cdfc7f..fa51bc7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -149,7 +149,7 @@ def dreambooth_strategy_callbacks(
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
-    @on_eval()
+    @torch.no_grad()
     def on_sample(step):
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0f72a17..73ec8f2 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -146,7 +146,7 @@ def lora_strategy_callbacks(
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
-    @on_eval()
+    @torch.no_grad()
     def on_sample(step):
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f00045f..08af89d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -142,7 +142,7 @@ def textual_inversion_strategy_callbacks(
                     checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
                 )
 
-    @on_eval()
+    @torch.no_grad()
     def on_sample(step):
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
-- 
cgit v1.2.3-70-g09d2