diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-20 13:31:11 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-20 13:31:11 +0100 |
| commit | f7d3f1e5caf675f1a0d1a172d382a0624b8d0165 (patch) | |
| tree | a8f5041ede4bf5d4aa041c8f8dd816737cbfeef1 /data | |
| parent | Fix Textual Inversion dataset filtering (diff) | |
| download | textual-inversion-diff-f7d3f1e5caf675f1a0d1a172d382a0624b8d0165.tar.gz textual-inversion-diff-f7d3f1e5caf675f1a0d1a172d382a0624b8d0165.tar.bz2 textual-inversion-diff-f7d3f1e5caf675f1a0d1a172d382a0624b8d0165.zip | |
Dependency cleanup/upgrades
Diffstat (limited to 'data')
| -rw-r--r-- | data/csv.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/data/csv.py b/data/csv.py index 6525e45..d400757 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -3,7 +3,6 @@ import torch | |||
| 3 | import json | 3 | import json |
| 4 | import numpy as np | 4 | import numpy as np |
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | import pytorch_lightning as pl | ||
| 7 | from PIL import Image | 6 | from PIL import Image |
| 8 | from torch.utils.data import Dataset, DataLoader, random_split | 7 | from torch.utils.data import Dataset, DataLoader, random_split |
| 9 | from torchvision import transforms | 8 | from torchvision import transforms |
| @@ -42,7 +41,7 @@ class CSVDataItem(NamedTuple): | |||
| 42 | nprompt: str | 41 | nprompt: str |
| 43 | 42 | ||
| 44 | 43 | ||
| 45 | class CSVDataModule(pl.LightningDataModule): | 44 | class CSVDataModule(): |
| 46 | def __init__( | 45 | def __init__( |
| 47 | self, | 46 | self, |
| 48 | batch_size: int, | 47 | batch_size: int, |
| @@ -141,7 +140,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 141 | items = [ | 140 | items = [ |
| 142 | item | 141 | item |
| 143 | for item in items | 142 | for item in items |
| 144 | if "mode" in item and self.mode in item["mode"] | 143 | if "mode" in item and self.mode in item["mode"].split(", ") |
| 145 | ] | 144 | ] |
| 146 | items = self.prepare_items(template, expansions, items) | 145 | items = self.prepare_items(template, expansions, items) |
| 147 | items = self.filter_items(items) | 146 | items = self.filter_items(items) |
