summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--training/functional.py51
-rw-r--r--training/strategy/dreambooth.py2
-rw-r--r--training/strategy/lora.py2
-rw-r--r--training/strategy/ti.py2
4 files changed, 30 insertions, 27 deletions
diff --git a/training/functional.py b/training/functional.py
index ff6d3a9..4220c79 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -143,28 +143,29 @@ def save_samples(
143 for prompt in batch["nprompt_ids"] 143 for prompt in batch["nprompt_ids"]
144 ] 144 ]
145 145
146 for i in range(num_batches): 146 with torch.inference_mode():
147 start = i * batch_size 147 for i in range(num_batches):
148 end = (i + 1) * batch_size 148 start = i * batch_size
149 prompt = prompt_ids[start:end] 149 end = (i + 1) * batch_size
150 nprompt = nprompt_ids[start:end] 150 prompt = prompt_ids[start:end]
151 151 nprompt = nprompt_ids[start:end]
152 samples = pipeline( 152
153 prompt=prompt, 153 samples = pipeline(
154 negative_prompt=nprompt, 154 prompt=prompt,
155 height=image_size, 155 negative_prompt=nprompt,
156 width=image_size, 156 height=image_size,
157 generator=gen, 157 width=image_size,
158 guidance_scale=guidance_scale, 158 generator=gen,
159 sag_scale=0, 159 guidance_scale=guidance_scale,
160 num_inference_steps=num_steps, 160 sag_scale=0,
161 output_type='pil' 161 num_inference_steps=num_steps,
162 ).images 162 output_type='pil'
163 163 ).images
164 all_samples += samples 164
165 165 all_samples += samples
166 image_grid = make_grid(all_samples, grid_rows, grid_cols) 166
167 image_grid.save(file_path, quality=85) 167 image_grid = make_grid(all_samples, grid_rows, grid_cols)
168 image_grid.save(file_path, quality=85)
168 169
169 del generator 170 del generator
170 del pipeline 171 del pipeline
@@ -482,7 +483,8 @@ def train_loop(
482 local_progress_bar.clear() 483 local_progress_bar.clear()
483 global_progress_bar.clear() 484 global_progress_bar.clear()
484 485
485 on_sample(global_step + global_step_offset) 486 with on_eval():
487 on_sample(global_step + global_step_offset)
486 488
487 if epoch % checkpoint_frequency == 0 and epoch != 0: 489 if epoch % checkpoint_frequency == 0 and epoch != 0:
488 local_progress_bar.clear() 490 local_progress_bar.clear()
@@ -606,7 +608,8 @@ def train_loop(
606 # Create the pipeline using using the trained modules and save it. 608 # Create the pipeline using using the trained modules and save it.
607 if accelerator.is_main_process: 609 if accelerator.is_main_process:
608 print("Finished!") 610 print("Finished!")
609 on_sample(global_step + global_step_offset) 611 with on_eval():
612 on_sample(global_step + global_step_offset)
610 on_checkpoint(global_step + global_step_offset, "end") 613 on_checkpoint(global_step + global_step_offset, "end")
611 614
612 except KeyboardInterrupt: 615 except KeyboardInterrupt:
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 7cdfc7f..fa51bc7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -149,7 +149,7 @@ def dreambooth_strategy_callbacks(
149 if torch.cuda.is_available(): 149 if torch.cuda.is_available():
150 torch.cuda.empty_cache() 150 torch.cuda.empty_cache()
151 151
152 @on_eval() 152 @torch.no_grad()
153 def on_sample(step): 153 def on_sample(step):
154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 154 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 155 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 0f72a17..73ec8f2 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -146,7 +146,7 @@ def lora_strategy_callbacks(
146 if torch.cuda.is_available(): 146 if torch.cuda.is_available():
147 torch.cuda.empty_cache() 147 torch.cuda.empty_cache()
148 148
149 @on_eval() 149 @torch.no_grad()
150 def on_sample(step): 150 def on_sample(step):
151 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 151 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
152 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 152 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f00045f..08af89d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -142,7 +142,7 @@ def textual_inversion_strategy_callbacks(
142 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" 142 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
143 ) 143 )
144 144
145 @on_eval() 145 @torch.no_grad()
146 def on_sample(step): 146 def on_sample(step):
147 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 147 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
148 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 148 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)