diff options
Diffstat (limited to 'data/dreambooth')
-rw-r--r-- | data/dreambooth/csv.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4ebdc13..9075979 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -150,7 +150,7 @@ class CSVDataset(Dataset): | |||
150 | max_length=self.tokenizer.model_max_length, | 150 | max_length=self.tokenizer.model_max_length, |
151 | ).input_ids | 151 | ).input_ids |
152 | 152 | ||
153 | if self.class_identifier: | 153 | if self.class_identifier is not None: |
154 | class_image = Image.open(class_image_path) | 154 | class_image = Image.open(class_image_path) |
155 | if not class_image.mode == "RGB": | 155 | if not class_image.mode == "RGB": |
156 | class_image = class_image.convert("RGB") | 156 | class_image = class_image.convert("RGB") |
@@ -177,7 +177,7 @@ class CSVDataset(Dataset): | |||
177 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 177 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
178 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 178 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
179 | 179 | ||
180 | if self.class_identifier: | 180 | if self.class_identifier is not None: |
181 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 181 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
182 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 182 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] |
183 | 183 | ||