summaryrefslogtreecommitdiffstats
path: root/data/dreambooth
diff options
context:
space:
mode:
Diffstat (limited to 'data/dreambooth')
-rw-r--r--data/dreambooth/csv.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 85ed4a5..99bcf12 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -1,3 +1,4 @@
1import math
1import os 2import os
2import pandas as pd 3import pandas as pd
3from pathlib import Path 4from pathlib import Path
@@ -57,11 +58,10 @@ class CSVDataModule(pl.LightningDataModule):
57 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, 58 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
58 class_data_root=self.class_data_root, class_prompt=self.class_prompt, 59 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
59 size=self.size, interpolation=self.interpolation, identifier=self.identifier, 60 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
60 center_crop=self.center_crop, repeats=self.repeats) 61 center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size)
61 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, 62 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
62 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
63 size=self.size, interpolation=self.interpolation, identifier=self.identifier, 63 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
64 center_crop=self.center_crop) 64 center_crop=self.center_crop, batch_size=self.batch_size)
65 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 65 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
66 shuffle=True, collate_fn=self.collate_fn) 66 shuffle=True, collate_fn=self.collate_fn)
67 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn) 67 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
@@ -85,22 +85,24 @@ class CSVDataset(Dataset):
85 interpolation="bicubic", 85 interpolation="bicubic",
86 identifier="*", 86 identifier="*",
87 center_crop=False, 87 center_crop=False,
88 batch_size=1,
88 ): 89 ):
89 90
90 self.data = data 91 self.data = data
91 self.tokenizer = tokenizer 92 self.tokenizer = tokenizer
92 self.instance_prompt = instance_prompt 93 self.instance_prompt = instance_prompt
94 self.identifier = identifier
95 self.batch_size = batch_size
96 self.cache = {}
93 97
94 self.num_instance_images = len(self.data) 98 self.num_instance_images = len(self.data)
95 self._length = self.num_instance_images * repeats 99 self._length = self.num_instance_images * repeats
96 100
97 self.identifier = identifier
98
99 if class_data_root is not None: 101 if class_data_root is not None:
100 self.class_data_root = Path(class_data_root) 102 self.class_data_root = Path(class_data_root)
101 self.class_data_root.mkdir(parents=True, exist_ok=True) 103 self.class_data_root.mkdir(parents=True, exist_ok=True)
102 104
103 self.class_images = list(Path(class_data_root).iterdir()) 105 self.class_images = list(self.class_data_root.iterdir())
104 self.num_class_images = len(self.class_images) 106 self.num_class_images = len(self.class_images)
105 self._length = max(self.num_class_images, self.num_instance_images) 107 self._length = max(self.num_class_images, self.num_instance_images)
106 108
@@ -123,10 +125,8 @@ class CSVDataset(Dataset):
123 ] 125 ]
124 ) 126 )
125 127
126 self.cache = {}
127
128 def __len__(self): 128 def __len__(self):
129 return self._length 129 return math.ceil(self._length / self.batch_size) * self.batch_size
130 130
131 def get_example(self, i): 131 def get_example(self, i):
132 image_path, text = self.data[i % self.num_instance_images] 132 image_path, text = self.data[i % self.num_instance_images]