diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index c6b4dc3..b6b5d87 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol | |||
| 17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
| 18 | from PIL import Image | 18 | from PIL import Image |
| 19 | 19 | ||
| 20 | from data.csv import VlpnDataset | ||
| 20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 21 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 21 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
| 22 | from models.clip.util import get_extended_embeddings | 23 | from models.clip.util import get_extended_embeddings |
| @@ -175,12 +176,12 @@ def generate_class_images( | |||
| 175 | unet: UNet2DConditionModel, | 176 | unet: UNet2DConditionModel, |
| 176 | tokenizer: MultiCLIPTokenizer, | 177 | tokenizer: MultiCLIPTokenizer, |
| 177 | sample_scheduler: DPMSolverMultistepScheduler, | 178 | sample_scheduler: DPMSolverMultistepScheduler, |
| 178 | data_train, | 179 | train_dataset: VlpnDataset, |
| 179 | sample_batch_size: int, | 180 | sample_batch_size: int, |
| 180 | sample_image_size: int, | 181 | sample_image_size: int, |
| 181 | sample_steps: int | 182 | sample_steps: int |
| 182 | ): | 183 | ): |
| 183 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | 184 | missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] |
| 184 | 185 | ||
| 185 | if len(missing_data) == 0: | 186 | if len(missing_data) == 0: |
| 186 | return | 187 | return |
