From 6c8cffe28baeafac77d047ff3f8ded9418033e2f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 15:52:43 +0100 Subject: More training adjustments --- training/functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index c6b4dc3..b6b5d87 100644 --- a/training/functional.py +++ b/training/functional.py @@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol from tqdm.auto import tqdm from PIL import Image +from data.csv import VlpnDataset from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings from models.clip.util import get_extended_embeddings @@ -175,12 +176,12 @@ def generate_class_images( unet: UNet2DConditionModel, tokenizer: MultiCLIPTokenizer, sample_scheduler: DPMSolverMultistepScheduler, - data_train, + train_dataset: VlpnDataset, sample_batch_size: int, sample_image_size: int, sample_steps: int ): - missing_data = [item for item in data_train if not item.class_image_path.exists()] + missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] if len(missing_data) == 0: return -- cgit v1.2.3-54-g00ecf