diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index c6b4dc3..b6b5d87 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol | |||
17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
18 | from PIL import Image | 18 | from PIL import Image |
19 | 19 | ||
20 | from data.csv import VlpnDataset | ||
20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 21 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
21 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
22 | from models.clip.util import get_extended_embeddings | 23 | from models.clip.util import get_extended_embeddings |
@@ -175,12 +176,12 @@ def generate_class_images( | |||
175 | unet: UNet2DConditionModel, | 176 | unet: UNet2DConditionModel, |
176 | tokenizer: MultiCLIPTokenizer, | 177 | tokenizer: MultiCLIPTokenizer, |
177 | sample_scheduler: DPMSolverMultistepScheduler, | 178 | sample_scheduler: DPMSolverMultistepScheduler, |
178 | data_train, | 179 | train_dataset: VlpnDataset, |
179 | sample_batch_size: int, | 180 | sample_batch_size: int, |
180 | sample_image_size: int, | 181 | sample_image_size: int, |
181 | sample_steps: int | 182 | sample_steps: int |
182 | ): | 183 | ): |
183 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | 184 | missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] |
184 | 185 | ||
185 | if len(missing_data) == 0: | 186 | if len(missing_data) == 0: |
186 | return | 187 | return |