summaryrefslogtreecommitdiffstats
path: root/data
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-20 13:31:11 +0100
committerVolpeon <git@volpeon.ink>2022-12-20 13:31:11 +0100
commitf7d3f1e5caf675f1a0d1a172d382a0624b8d0165 (patch)
treea8f5041ede4bf5d4aa041c8f8dd816737cbfeef1 /data
parentFix Textual Inversion dataset filtering (diff)
downloadtextual-inversion-diff-f7d3f1e5caf675f1a0d1a172d382a0624b8d0165.tar.gz
textual-inversion-diff-f7d3f1e5caf675f1a0d1a172d382a0624b8d0165.tar.bz2
textual-inversion-diff-f7d3f1e5caf675f1a0d1a172d382a0624b8d0165.zip
Dependency cleanup/upgrades
Diffstat (limited to 'data')
-rw-r--r--data/csv.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py
index 6525e45..d400757 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -3,7 +3,6 @@ import torch
3import json 3import json
4import numpy as np 4import numpy as np
5from pathlib import Path 5from pathlib import Path
6import pytorch_lightning as pl
7from PIL import Image 6from PIL import Image
8from torch.utils.data import Dataset, DataLoader, random_split 7from torch.utils.data import Dataset, DataLoader, random_split
9from torchvision import transforms 8from torchvision import transforms
@@ -42,7 +41,7 @@ class CSVDataItem(NamedTuple):
42 nprompt: str 41 nprompt: str
43 42
44 43
45class CSVDataModule(pl.LightningDataModule): 44class CSVDataModule():
46 def __init__( 45 def __init__(
47 self, 46 self,
48 batch_size: int, 47 batch_size: int,
@@ -141,7 +140,7 @@ class CSVDataModule(pl.LightningDataModule):
141 items = [ 140 items = [
142 item 141 item
143 for item in items 142 for item in items
144 if "mode" in item and self.mode in item["mode"] 143 if "mode" in item and self.mode in item["mode"].split(", ")
145 ] 144 ]
146 items = self.prepare_items(template, expansions, items) 145 items = self.prepare_items(template, expansions, items)
147 items = self.filter_items(items) 146 items = self.filter_items(items)