diff options
| -rw-r--r-- | data/csv.py | 5 | ||||
| -rw-r--r-- | environment.yaml | 34 |
2 files changed, 12 insertions, 27 deletions
diff --git a/data/csv.py b/data/csv.py index 6525e45..d400757 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -3,7 +3,6 @@ import torch | |||
| 3 | import json | 3 | import json |
| 4 | import numpy as np | 4 | import numpy as np |
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | import pytorch_lightning as pl | ||
| 7 | from PIL import Image | 6 | from PIL import Image |
| 8 | from torch.utils.data import Dataset, DataLoader, random_split | 7 | from torch.utils.data import Dataset, DataLoader, random_split |
| 9 | from torchvision import transforms | 8 | from torchvision import transforms |
| @@ -42,7 +41,7 @@ class CSVDataItem(NamedTuple): | |||
| 42 | nprompt: str | 41 | nprompt: str |
| 43 | 42 | ||
| 44 | 43 | ||
| 45 | class CSVDataModule(pl.LightningDataModule): | 44 | class CSVDataModule(): |
| 46 | def __init__( | 45 | def __init__( |
| 47 | self, | 46 | self, |
| 48 | batch_size: int, | 47 | batch_size: int, |
| @@ -141,7 +140,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 141 | items = [ | 140 | items = [ |
| 142 | item | 141 | item |
| 143 | for item in items | 142 | for item in items |
| 144 | if "mode" in item and self.mode in item["mode"] | 143 | if "mode" in item and self.mode in item["mode"].split(", ") |
| 145 | ] | 144 | ] |
| 146 | items = self.prepare_items(template, expansions, items) | 145 | items = self.prepare_items(template, expansions, items) |
| 147 | items = self.filter_items(items) | 146 | items = self.filter_items(items) |
diff --git a/environment.yaml b/environment.yaml index 57e090c..e598f72 100644 --- a/environment.yaml +++ b/environment.yaml | |||
| @@ -5,34 +5,20 @@ channels: | |||
| 5 | - defaults | 5 | - defaults |
| 6 | dependencies: | 6 | dependencies: |
| 7 | - cudatoolkit=11.3 | 7 | - cudatoolkit=11.3 |
| 8 | - numpy=1.22.3 | 8 | - numpy=1.23.4 |
| 9 | - pip=20.3 | 9 | - pip=22.3.1 |
| 10 | - python=3.9.13 | 10 | - python=3.9.15 |
| 11 | - pytorch=1.12.1 | 11 | - pytorch=1.13.1 |
| 12 | - torchvision=0.13.1 | 12 | - torchvision=0.14.1 |
| 13 | - pandas=1.4.3 | ||
| 14 | - xformers=0.0.15.dev344 | ||
| 15 | - pip: | 13 | - pip: |
| 16 | - -e . | 14 | - -e . |
| 17 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers | ||
| 18 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip | ||
| 19 | - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion | ||
| 20 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 15 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
| 21 | - accelerate==0.12.0 | 16 | - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion |
| 22 | - albumentations==1.1.0 | 17 | - accelerate==0.15.0 |
| 23 | - bitsandbytes==0.35.4 | 18 | - bitsandbytes==0.35.4 |
| 24 | - einops==0.4.1 | ||
| 25 | - imageio==2.22.0 | ||
| 26 | - k-diffusion==0.0.12 | ||
| 27 | - kornia==0.6 | ||
| 28 | - pudb==2019.2 | ||
| 29 | - omegaconf==2.2.3 | ||
| 30 | - opencv-python-headless==4.6.0.66 | ||
| 31 | - python-slugify>=6.1.2 | 19 | - python-slugify>=6.1.2 |
| 32 | - pytorch-lightning==1.7.7 | 20 | - setuptools==65.6.3 |
| 33 | - setuptools==59.5.0 | ||
| 34 | - test-tube>=0.7.5 | 21 | - test-tube>=0.7.5 |
| 35 | - torch-fidelity==0.3.0 | ||
| 36 | - torchmetrics==0.9.3 | ||
| 37 | - transformers==4.25.1 | 22 | - transformers==4.25.1 |
| 38 | - triton==2.0.0.dev20221105 | 23 | - triton==2.0.0.dev20221202 |
| 24 | - xformers==0.0.16rc391 | ||
