summaryrefslogtreecommitdiffstats
path: root/data/dreambooth
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth')
-rw-r--r--data/dreambooth/csv.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 14c13bb..85ed4a5 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -108,14 +108,14 @@ class CSVDataset(Dataset):
108 else: 108 else:
109 self.class_data_root = None 109 self.class_data_root = None
110 110
111 self.interpolation = {"linear": PIL.Image.LINEAR, 111 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
112 "bilinear": PIL.Image.BILINEAR, 112 "bilinear": transforms.InterpolationMode.BILINEAR,
113 "bicubic": PIL.Image.BICUBIC, 113 "bicubic": transforms.InterpolationMode.BICUBIC,
114 "lanczos": PIL.Image.LANCZOS, 114 "lanczos": transforms.InterpolationMode.LANCZOS,
115 }[interpolation] 115 }[interpolation]
116 self.image_transforms = transforms.Compose( 116 self.image_transforms = transforms.Compose(
117 [ 117 [
118 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 118 transforms.Resize(size, interpolation=self.interpolation),
119 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 119 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
120 transforms.RandomHorizontalFlip(), 120 transforms.RandomHorizontalFlip(),
121 transforms.ToTensor(), 121 transforms.ToTensor(),