From 0462f5d20fdc82806753629b6f5c5fb39a88d1d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 4 Oct 2022 09:19:52 +0200 Subject: Bugfix --- data/dreambooth/csv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'data') diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4ebdc13..9075979 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -150,7 +150,7 @@ class CSVDataset(Dataset): max_length=self.tokenizer.model_max_length, ).input_ids - if self.class_identifier: + if self.class_identifier is not None: class_image = Image.open(class_image_path) if not class_image.mode == "RGB": class_image = class_image.convert("RGB") @@ -177,7 +177,7 @@ class CSVDataset(Dataset): example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] - if self.class_identifier: + if self.class_identifier is not None: example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] -- cgit v1.2.3-54-g00ecf