summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py
index 43b03ac..546aaff 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
2import math 2import math
3from contextlib import _GeneratorContextManager, nullcontext 3from contextlib import _GeneratorContextManager, nullcontext
4from typing import Callable, Any, Tuple, Union, Optional, Protocol 4from typing import Callable, Any, Tuple, Union, Optional, Protocol
5from types import MethodType
5from functools import partial 6from functools import partial
6from pathlib import Path 7from pathlib import Path
7import itertools 8import itertools
@@ -108,6 +109,7 @@ def save_samples(
108 output_dir: Path, 109 output_dir: Path,
109 seed: int, 110 seed: int,
110 step: int, 111 step: int,
112 validation_prompts: list[str] = [],
111 cycle: int = 1, 113 cycle: int = 1,
112 batch_size: int = 1, 114 batch_size: int = 1,
113 num_batches: int = 1, 115 num_batches: int = 1,
@@ -136,7 +138,6 @@ def save_samples(
136 138
137 if val_dataloader is not None: 139 if val_dataloader is not None:
138 datasets.append(("stable", val_dataloader, generator)) 140 datasets.append(("stable", val_dataloader, generator))
139 datasets.append(("val", val_dataloader, None))
140 141
141 for pool, data, gen in datasets: 142 for pool, data, gen in datasets:
142 all_samples = [] 143 all_samples = []
@@ -165,7 +166,6 @@ def save_samples(
165 guidance_scale=guidance_scale, 166 guidance_scale=guidance_scale,
166 sag_scale=0, 167 sag_scale=0,
167 num_inference_steps=num_steps, 168 num_inference_steps=num_steps,
168 output_type=None,
169 ).images 169 ).images
170 170
171 all_samples.append(torch.from_numpy(samples)) 171 all_samples.append(torch.from_numpy(samples))
@@ -803,4 +803,7 @@ def train(
803 accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 803 accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
804 accelerator.unwrap_model(unet, keep_fp32_wrapper=False) 804 accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
805 805
806 text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
807 unet.forward = MethodType(unet.forward, unet)
808
806 accelerator.free_memory() 809 accelerator.free_memory()