diff options
Diffstat (limited to 'data/dreambooth')
-rw-r--r-- | data/dreambooth/csv.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 14c13bb..85ed4a5 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -108,14 +108,14 @@ class CSVDataset(Dataset): | |||
108 | else: | 108 | else: |
109 | self.class_data_root = None | 109 | self.class_data_root = None |
110 | 110 | ||
111 | self.interpolation = {"linear": PIL.Image.LINEAR, | 111 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, |
112 | "bilinear": PIL.Image.BILINEAR, | 112 | "bilinear": transforms.InterpolationMode.BILINEAR, |
113 | "bicubic": PIL.Image.BICUBIC, | 113 | "bicubic": transforms.InterpolationMode.BICUBIC, |
114 | "lanczos": PIL.Image.LANCZOS, | 114 | "lanczos": transforms.InterpolationMode.LANCZOS, |
115 | }[interpolation] | 115 | }[interpolation] |
116 | self.image_transforms = transforms.Compose( | 116 | self.image_transforms = transforms.Compose( |
117 | [ | 117 | [ |
118 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), | 118 | transforms.Resize(size, interpolation=self.interpolation), |
119 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), | 119 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), |
120 | transforms.RandomHorizontalFlip(), | 120 | transforms.RandomHorizontalFlip(), |
121 | transforms.ToTensor(), | 121 | transforms.ToTensor(), |