summaryrefslogtreecommitdiffstats
path: root/data/dreambooth
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth')
-rw-r--r--data/dreambooth/csv.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 4ebdc13..9075979 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -150,7 +150,7 @@ class CSVDataset(Dataset):
150 max_length=self.tokenizer.model_max_length, 150 max_length=self.tokenizer.model_max_length,
151 ).input_ids 151 ).input_ids
152 152
153 if self.class_identifier: 153 if self.class_identifier is not None:
154 class_image = Image.open(class_image_path) 154 class_image = Image.open(class_image_path)
155 if not class_image.mode == "RGB": 155 if not class_image.mode == "RGB":
156 class_image = class_image.convert("RGB") 156 class_image = class_image.convert("RGB")
@@ -177,7 +177,7 @@ class CSVDataset(Dataset):
177 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) 177 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
178 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] 178 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
179 179
180 if self.class_identifier: 180 if self.class_identifier is not None:
181 example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) 181 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
182 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] 182 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
183 183