diff options
-rw-r--r-- | .gitignore | 164 | ||||
-rw-r--r-- | .pep8 | 2 | ||||
-rw-r--r-- | data.py | 145 | ||||
-rw-r--r-- | environment.yaml | 36 | ||||
-rw-r--r-- | main.py | 784 | ||||
-rw-r--r-- | pipelines/stable_diffusion/no_check.py | 13 | ||||
-rw-r--r-- | scripts/convert_original_stable_diffusion_to_diffusers.py | 690 | ||||
-rw-r--r-- | setup.py | 13 |
8 files changed, 1847 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..00e7681 --- /dev/null +++ b/.gitignore | |||
@@ -0,0 +1,164 @@ | |||
1 | # Byte-compiled / optimized / DLL files | ||
2 | __pycache__/ | ||
3 | *.py[cod] | ||
4 | *$py.class | ||
5 | |||
6 | # C extensions | ||
7 | *.so | ||
8 | |||
9 | # Distribution / packaging | ||
10 | .Python | ||
11 | build/ | ||
12 | develop-eggs/ | ||
13 | dist/ | ||
14 | downloads/ | ||
15 | eggs/ | ||
16 | .eggs/ | ||
17 | lib/ | ||
18 | lib64/ | ||
19 | parts/ | ||
20 | sdist/ | ||
21 | var/ | ||
22 | wheels/ | ||
23 | share/python-wheels/ | ||
24 | *.egg-info/ | ||
25 | .installed.cfg | ||
26 | *.egg | ||
27 | MANIFEST | ||
28 | |||
29 | # PyInstaller | ||
30 | # Usually these files are written by a python script from a template | ||
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
32 | *.manifest | ||
33 | *.spec | ||
34 | |||
35 | # Installer logs | ||
36 | pip-log.txt | ||
37 | pip-delete-this-directory.txt | ||
38 | |||
39 | # Unit test / coverage reports | ||
40 | htmlcov/ | ||
41 | .tox/ | ||
42 | .nox/ | ||
43 | .coverage | ||
44 | .coverage.* | ||
45 | .cache | ||
46 | nosetests.xml | ||
47 | coverage.xml | ||
48 | *.cover | ||
49 | *.py,cover | ||
50 | .hypothesis/ | ||
51 | .pytest_cache/ | ||
52 | cover/ | ||
53 | |||
54 | # Translations | ||
55 | *.mo | ||
56 | *.pot | ||
57 | |||
58 | # Django stuff: | ||
59 | *.log | ||
60 | local_settings.py | ||
61 | db.sqlite3 | ||
62 | db.sqlite3-journal | ||
63 | |||
64 | # Flask stuff: | ||
65 | instance/ | ||
66 | .webassets-cache | ||
67 | |||
68 | # Scrapy stuff: | ||
69 | .scrapy | ||
70 | |||
71 | # Sphinx documentation | ||
72 | docs/_build/ | ||
73 | |||
74 | # PyBuilder | ||
75 | .pybuilder/ | ||
76 | target/ | ||
77 | |||
78 | # Jupyter Notebook | ||
79 | .ipynb_checkpoints | ||
80 | |||
81 | # IPython | ||
82 | profile_default/ | ||
83 | ipython_config.py | ||
84 | |||
85 | # pyenv | ||
86 | # For a library or package, you might want to ignore these files since the code is | ||
87 | # intended to run in multiple environments; otherwise, check them in: | ||
88 | # .python-version | ||
89 | |||
90 | # pipenv | ||
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
94 | # install all needed dependencies. | ||
95 | #Pipfile.lock | ||
96 | |||
97 | # poetry | ||
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more | ||
100 | # commonly ignored for libraries. | ||
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
102 | #poetry.lock | ||
103 | |||
104 | # pdm | ||
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | ||
106 | #pdm.lock | ||
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | ||
108 | # in version control. | ||
109 | # https://pdm.fming.dev/#use-with-ide | ||
110 | .pdm.toml | ||
111 | |||
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||
113 | __pypackages__/ | ||
114 | |||
115 | # Celery stuff | ||
116 | celerybeat-schedule | ||
117 | celerybeat.pid | ||
118 | |||
119 | # SageMath parsed files | ||
120 | *.sage.py | ||
121 | |||
122 | # Environments | ||
123 | .env | ||
124 | .venv | ||
125 | env/ | ||
126 | venv/ | ||
127 | ENV/ | ||
128 | env.bak/ | ||
129 | venv.bak/ | ||
130 | |||
131 | # Spyder project settings | ||
132 | .spyderproject | ||
133 | .spyproject | ||
134 | |||
135 | # Rope project settings | ||
136 | .ropeproject | ||
137 | |||
138 | # mkdocs documentation | ||
139 | /site | ||
140 | |||
141 | # mypy | ||
142 | .mypy_cache/ | ||
143 | .dmypy.json | ||
144 | dmypy.json | ||
145 | |||
146 | # Pyre type checker | ||
147 | .pyre/ | ||
148 | |||
149 | # pytype static type analyzer | ||
150 | .pytype/ | ||
151 | |||
152 | # Cython debug symbols | ||
153 | cython_debug/ | ||
154 | |||
155 | # PyCharm | ||
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear | ||
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
160 | #.idea/ | ||
161 | |||
162 | text-inversion-model/ | ||
163 | conf.json | ||
164 | v1-inference.yaml | ||
@@ -0,0 +1,2 @@ | |||
1 | [pycodestyle] | ||
2 | max_line_length = 120 | ||
@@ -0,0 +1,145 @@ | |||
1 | import os | ||
2 | import numpy as np | ||
3 | import pandas as pd | ||
4 | import random | ||
5 | import PIL | ||
6 | import pytorch_lightning as pl | ||
7 | from PIL import Image | ||
8 | import torch | ||
9 | from torch.utils.data import Dataset, DataLoader, random_split | ||
10 | from torchvision import transforms | ||
11 | |||
12 | |||
class CSVDataModule(pl.LightningDataModule):
    """Lightning data module built from ``<data_root>/list.csv``.

    The CSV is read for its ``image``, ``caption`` and ``skip`` columns. Rows
    whose ``skip`` value equals ``"x"`` are dropped; the remaining
    ``(image_path, caption)`` pairs are split 80/20 into a training and a
    validation ``CSVDataset``.
    """

    def __init__(self,
                 batch_size,
                 data_root,
                 tokenizer,
                 size=512,
                 repeats=100,
                 interpolation="bicubic",
                 placeholder_token="*",
                 flip_p=0.5,
                 center_crop=False):
        super().__init__()

        self.batch_size = batch_size

        self.data_root = data_root
        self.tokenizer = tokenizer

        # Options forwarded unchanged to CSVDataset in setup().
        self.size = size
        self.repeats = repeats
        self.interpolation = interpolation
        self.placeholder_token = placeholder_token
        self.flip_p = flip_p
        self.center_crop = center_crop

    def prepare_data(self):
        """Read the CSV index and store the non-skipped rows in ``self.data_full``."""
        table = pd.read_csv(f'{self.data_root}/list.csv')
        rows = zip(table['image'].values, table['caption'].values, table['skip'].values)

        kept = []
        for relative_path, caption, skip_flag in rows:
            full_path = os.path.join(self.data_root, relative_path)
            if skip_flag != "x":
                kept.append((full_path, caption))
        self.data_full = kept

    def setup(self, stage=None):
        """Split ``self.data_full`` 80/20 and build both dataloaders."""
        total = len(self.data_full)
        n_train = int(total * 0.8)
        n_val = total - n_train
        self.data_train, self.data_val = random_split(self.data_full, [n_train, n_val])

        shared_options = dict(
            size=self.size,
            interpolation=self.interpolation,
            flip_p=self.flip_p,
            placeholder_token=self.placeholder_token,
            center_crop=self.center_crop,
        )
        # Only the training set is repeated; validation uses CSVDataset's default.
        train_dataset = CSVDataset(self.data_train, self.tokenizer, repeats=self.repeats, **shared_options)
        val_dataset = CSVDataset(self.data_val, self.tokenizer, **shared_options)

        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        """Return the shuffled training dataloader created by ``setup``."""
        return self.train_dataloader_

    def val_dataloader(self):
        """Return the validation dataloader created by ``setup``."""
        return self.val_dataloader_
61 | |||
62 | |||
class CSVDataset(Dataset):
    """Dataset of ``(image_path, caption)`` pairs for textual-inversion training.

    Each item is a dict with:

    * ``prompt``       -- the caption after ``str.format(placeholder_token)``
    * ``input_ids``    -- tokenizer ids of the prompt, padded/truncated to
      ``tokenizer.model_max_length``
    * ``pixel_values`` -- float32 tensor in channel-first layout, scaled to [-1, 1]
    * ``key``          -- string built from the image path and the flip flag

    The expensive work (decode, crop, resize, tokenize) is cached per image
    path. The horizontal flip is applied *after* the cache lookup, so the
    augmentation stays random across epochs and ``key`` always describes the
    pixel data that is actually returned.
    """

    def __init__(self,
                 data,
                 tokenizer,
                 size=512,
                 repeats=1,
                 interpolation="bicubic",
                 flip_p=0.5,
                 placeholder_token="*",
                 center_crop=False,
                 ):

        self.data = data
        self.tokenizer = tokenizer

        self.num_images = len(self.data)
        self._length = self.num_images * repeats

        self.placeholder_token = placeholder_token

        self.size = size
        self.center_crop = center_crop
        # "linear" maps to BILINEAR: PIL.Image.LINEAR was only an alias for it
        # and no longer exists in Pillow >= 10.
        self.interpolation = {"linear": PIL.Image.BILINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip_p = flip_p
        # Kept so existing code that reads ``dataset.flip`` keeps working; the
        # flip itself is now applied deterministically in get_example().
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        # image_path -> un-flipped base example (prompt, input_ids, pixel_values)
        self.cache = {}

    def __len__(self):
        return self._length

    def _load_base_example(self, image_path, text):
        """Decode, crop, resize and tokenize one sample (no flip applied)."""
        # convert() returns a new, fully loaded image, so the file handle can
        # be released immediately.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        text = text.format(self.placeholder_token)
        input_ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[(h - crop) // 2:(h + crop) // 2,
                      (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)
        pixels = np.array(image).astype(np.uint8)
        pixels = (pixels / 127.5 - 1.0).astype(np.float32)

        return {
            "prompt": text,
            "input_ids": input_ids,
            # HWC -> CHW
            "pixel_values": torch.from_numpy(pixels).permute(2, 0, 1),
        }

    def get_example(self, i, flipped):
        """Return sample ``i`` (modulo the number of images), mirrored iff ``flipped``."""
        image_path, text = self.data[i % self.num_images]

        base = self.cache.get(image_path)
        if base is None:
            base = self._load_base_example(image_path, text)
            self.cache[image_path] = base

        # Shallow copy so the cached entry is never modified.
        example = dict(base)
        if flipped:
            # The last axis of the CHW tensor is the image width.
            example["pixel_values"] = torch.flip(base["pixel_values"], dims=[-1])
        example["key"] = "-".join([image_path, "-", str(flipped)])
        return example

    def __getitem__(self, i):
        # Honour the configured flip probability instead of a fixed 50/50 draw.
        flipped = random.random() < self.flip_p
        return self.get_example(i, flipped)
diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..a460158 --- /dev/null +++ b/environment.yaml | |||
@@ -0,0 +1,36 @@ | |||
1 | name: ldd | ||
2 | channels: | ||
3 | - pytorch | ||
4 | - defaults | ||
5 | dependencies: | ||
6 | - cudatoolkit=11.3 | ||
7 | - numpy=1.22.3 | ||
8 | - pip=20.3 | ||
9 | - python=3.8.10 | ||
10 | - pytorch=1.12.1 | ||
11 | - torchvision=0.13.1 | ||
12 | - pandas=1.4.3 | ||
13 | - pip: | ||
14 | - -e . | ||
15 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers | ||
16 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip | ||
17 | - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion | ||
18 | - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion | ||
19 | - accelerate==0.12.0 | ||
20 | - albumentations==1.1.0 | ||
21 | - diffusers==0.3.0 | ||
22 | - einops==0.4.1 | ||
23 | - imageio-ffmpeg==0.4.7 | ||
24 | - imageio==2.14.1 | ||
25 | - kornia==0.6 | ||
26 | - pudb==2019.2 | ||
27 | - omegaconf==2.1.1 | ||
28 | - opencv-python-headless==4.6.0.66 | ||
29 | - python-slugify>=6.1.2 | ||
30 | - pytorch-lightning==1.7.7 | ||
31 | - setuptools==59.5.0 | ||
32 | - streamlit>=0.73.1 | ||
33 | - test-tube>=0.7.5 | ||
34 | - torch-fidelity==0.3.0 | ||
35 | - torchmetrics==0.9.3 | ||
36 | - transformers==4.19.2 | ||
@@ -0,0 +1,784 @@ | |||
1 | import argparse | ||
2 | import itertools | ||
3 | import math | ||
4 | import os | ||
5 | import random | ||
6 | import datetime | ||
7 | from pathlib import Path | ||
8 | from typing import Optional | ||
9 | |||
10 | import numpy as np | ||
11 | import torch | ||
12 | import torch.nn as nn | ||
13 | import torch.nn.functional as F | ||
14 | import torch.utils.checkpoint | ||
15 | from torch.utils.data import Dataset | ||
16 | |||
17 | import PIL | ||
18 | from accelerate import Accelerator | ||
19 | from accelerate.logging import get_logger | ||
20 | from accelerate.utils import LoggerType, set_seed | ||
21 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel | ||
22 | from diffusers.optimization import get_scheduler | ||
23 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
24 | from einops import rearrange | ||
25 | from pipelines.stable_diffusion.no_check import NoCheck | ||
26 | from huggingface_hub import HfFolder, Repository, whoami | ||
27 | from PIL import Image | ||
28 | from tqdm.auto import tqdm | ||
29 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | ||
30 | from slugify import slugify | ||
31 | import json | ||
32 | import os | ||
33 | import sys | ||
34 | |||
35 | from data import CSVDataModule | ||
36 | |||
37 | logger = get_logger(__name__) | ||
38 | |||
39 | |||
def parse_args():
    """Parse the training script's command line.

    If ``--resume_from`` is given, the ``"args"`` object of
    ``<resume_from>/resume.json`` pre-populates the namespace; otherwise, if
    ``--config`` is given, that JSON file does. In both cases the command line
    is parsed a second time on top of those values, so flags passed explicitly
    still override the file.

    Returns:
        argparse.Namespace with all options.

    Raises:
        ValueError: if ``train_data_dir``, ``pretrained_model_name_or_path``,
            ``placeholder_token``, ``initializer_token`` or ``output_dir`` is
            missing after all sources were merged.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir", type=str, default=None, help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token", type=str, default=None, help="A token to use as initializer word."
    )
    parser.add_argument(
        "--vectors_per_token", type=int, default=1, help="Vectors per token."
    )
    parser.add_argument("--repeats", type=int, default=100,
                        help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None,
                        help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    # --scale_lr defaults to True, so on its own it can never be switched off;
    # this companion flag writes False to the same destination.
    parser.add_argument(
        "--no_scale_lr",
        dest="scale_lr",
        action="store_false",
        help="Do not scale the learning rate (disables --scale_lr, which is on by default).",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9,
                        help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999,
                        help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float,
                        default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08,
                        help="Epsilon value for the Adam optimizer")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=500,
        help="How often to save a checkpoint and sample image",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=512,
        help="Size of sample images",
    )
    parser.add_argument(
        "--stable_sample_batches",
        type=int,
        default=1,
        help="Number of fixed seed sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--random_sample_batches",
        type=int,
        default=1,
        help="Number of random seed sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=50,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
    )
    parser.add_argument(
        "--resume_checkpoint",
        type=str,
        default=None,
        help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
    )

    args = parser.parse_args()

    # Re-parse with the stored values as the starting namespace: argparse only
    # fills in defaults for attributes the namespace does not already have, and
    # options present on the command line still overwrite them.
    if args.resume_from is not None:
        with open(f"{args.resume_from}/resume.json", 'rt') as f:
            stored = json.load(f)["args"]
        args = parser.parse_args(namespace=argparse.Namespace(**stored))
    elif args.config is not None:
        with open(args.config, 'rt') as f:
            stored = json.load(f)
        args = parser.parse_args(namespace=argparse.Namespace(**stored))

    # A LOCAL_RANK set in the environment wins over --local_rank.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify --train_data_dir")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.placeholder_token is None:
        raise ValueError("You must specify --placeholder_token")

    if args.initializer_token is None:
        raise ValueError("You must specify --initializer_token")

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args
238 | |||
239 | |||
def freeze_params(params):
    """Set ``requires_grad = False`` on every object yielded by ``params``."""
    for parameter in params:
        parameter.requires_grad = False
243 | |||
244 | |||
def save_resume_file(basepath, args, extra=None):
    """Write ``<basepath>/resume.json`` holding the run's arguments.

    The file has the shape ``{"args": {...}}``, where the inner object is
    ``vars(args)`` with the entries of ``extra`` layered on top.

    Args:
        basepath: Directory in which ``resume.json`` is written.
        args: Namespace-like object; its attributes are read via ``vars``.
        extra: Optional mapping of additional entries to store; on a key clash
            these win over the values from ``args``.
    """
    # Build a fresh dict: vars(args) is the namespace's live __dict__, so
    # updating it in place would silently add the extra keys to ``args``.
    merged = {**vars(args), **(extra or {})}
    with open(f"{basepath}/resume.json", "w") as f:
        json.dump({"args": merged}, f, indent=4)
250 | |||
251 | |||
def make_grid(images, rows, cols):
    """Paste ``images`` row by row into one RGB sheet of ``rows`` x ``cols`` cells.

    The cell size is taken from the first image. Images are placed left to
    right, top to bottom, in the order given.
    """
    cell_w, cell_h = images[0].size
    sheet = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        sheet.paste(tile, box=(col * cell_w, row * cell_h))
    return sheet
258 | |||
259 | |||
class Checkpointer:
    """Saves the learned placeholder embedding and renders sample image grids.

    All constructor arguments are stored unchanged as attributes of the same
    name.
    """

    def __init__(
        self,
        datamodule,
        accelerator,
        vae,
        unet,
        tokenizer,
        placeholder_token,
        placeholder_token_id,
        output_dir,
        sample_image_size,
        random_sample_batches,
        sample_batch_size,
        stable_sample_batches,
        seed
    ):
        self.datamodule = datamodule
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.output_dir = output_dir
        self.sample_image_size = sample_image_size
        self.seed = seed
        self.random_sample_batches = random_sample_batches
        self.sample_batch_size = sample_batch_size
        self.stable_sample_batches = stable_sample_batches

    @torch.no_grad()
    def checkpoint(self, step, text_encoder, save_samples=True, path=None):
        """Save the placeholder token's embedding row of ``text_encoder``.

        Args:
            step: Training step, used in the progress message and file name.
            text_encoder: Possibly accelerator-wrapped text encoder.
            save_samples: Unused; kept so existing callers keep working.
            path: If given, the embedding is written to exactly this file.
                Otherwise it is written to
                ``<output_dir>/checkpoints/<slug>_<step>.bin`` and also to
                ``<output_dir>/checkpoints/last.bin``.
        """
        print("Saving checkpoint for step %d..." % step)
        with self.accelerator.autocast():
            if path is None:
                checkpoints_path = f"{self.output_dir}/checkpoints"
                os.makedirs(checkpoints_path, exist_ok=True)

            unwrapped = self.accelerator.unwrap_model(text_encoder)

            # The stored dict maps the placeholder token to its embedding row.
            learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
            learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}

            if path is not None:
                torch.save(learned_embeds_dict, path)
            else:
                filename = "%s_%d.bin" % (slugify(self.placeholder_token), step)
                torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
                torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")

    def _sample_prompts(self, data):
        """Return up to ``sample_batch_size`` prompt strings from the start of ``data``.

        Each batch's ``"prompt"`` entry is treated as a list of strings and
        flattened; iteration stops as soon as enough prompts were collected, so
        the rest of the dataloader is never loaded.
        """
        prompts = []
        for batch in data:
            prompts.extend(batch["prompt"])
            if len(prompts) >= self.sample_batch_size:
                break
        return prompts[:self.sample_batch_size]

    def _save_grid(self, images, rows, target):
        """Save ``images`` as a ``rows`` x ``sample_batch_size`` grid; no-op when empty."""
        if not images:
            return
        make_grid(images, rows, self.sample_batch_size).save(target)

    @torch.no_grad()
    def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
        """Render sample grids into ``<output_dir>/samples/<mode>``.

        Args:
            mode: ``"training"`` or ``"validation"``; selects the dataloader the
                prompts come from. Any other value raises ``KeyError``.
            step: Training step, used in the output file names.
            text_encoder: Possibly accelerator-wrapped text encoder.
            height, width: Pixel size used to shape the fixed-seed latents.
            guidance_scale, eta, num_inference_steps: Forwarded to the pipeline.

        In ``"validation"`` mode with ``stable_sample_batches > 0`` a grid from
        fixed-seed latents is written to ``stable_step_<step>.png``. A grid from
        fresh random latents is written to ``step_<step>.png`` whenever
        ``random_sample_batches > 0``.
        """
        data = {
            "training": self.datamodule.train_dataloader(),
            "validation": self.datamodule.val_dataloader(),
        }[mode]

        samples_path = f"{self.output_dir}/samples/{mode}"
        os.makedirs(samples_path, exist_ok=True)

        unwrapped = self.accelerator.unwrap_model(text_encoder)
        pipeline = StableDiffusionPipeline(
            text_encoder=unwrapped,
            vae=self.vae,
            unet=self.unet,
            tokenizer=self.tokenizer,
            scheduler=LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            ),
            safety_checker=NoCheck(),
            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
        ).to(self.accelerator.device)
        pipeline.enable_attention_slicing()

        if mode == "validation" and self.stable_sample_batches > 0:
            # NOTE(review): manual_seed needs an int, so this branch requires
            # self.seed to be set. The latents are sized from height/width while
            # the pipeline is asked for sample_image_size -- presumably callers
            # pass matching values; confirm.
            stable_latents = torch.randn(
                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
                device=pipeline.device,
                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
            )

            stable_images = []
            for _ in range(self.stable_sample_batches):
                prompt = self._sample_prompts(data)
                stable_images += pipeline(
                    prompt=prompt,
                    height=self.sample_image_size,
                    # Keep the latent batch in step with the prompts found.
                    latents=stable_latents[:len(prompt)],
                    width=self.sample_image_size,
                    guidance_scale=guidance_scale,
                    eta=eta,
                    num_inference_steps=num_inference_steps,
                    output_type='pil'
                )["sample"]

            self._save_grid(stable_images, self.stable_sample_batches,
                            f"{samples_path}/stable_step_%d.png" % (step))
            del stable_images
            del stable_latents

        random_images = []
        for _ in range(self.random_sample_batches):
            prompt = self._sample_prompts(data)
            random_images += pipeline(
                prompt=prompt,
                height=self.sample_image_size,
                width=self.sample_image_size,
                guidance_scale=guidance_scale,
                eta=eta,
                num_inference_steps=num_inference_steps,
                output_type='pil'
            )["sample"]

        self._save_grid(random_images, self.random_sample_batches,
                        f"{samples_path}/step_%d.png" % (step))
        del random_images

        # Drop the references before asking CUDA to release cached memory.
        del unwrapped
        del pipeline
        torch.cuda.empty_cache()
404 | |||
405 | |||
class ImageToLatents():
    """Callable that turns a batch into scaled half-precision VAE latents.

    The ``latent_dist`` returned by ``vae.encode`` is cached per batch, keyed
    by the batch's ``"key"`` entries joined with ``"|"``; a new sample is drawn
    from the cached distribution on every call.
    """

    def __init__(self, vae):
        self.vae = vae
        self.encoded_pixel_values_cache = {}

    @torch.no_grad()
    def __call__(self, batch):
        cache_key = "|".join(batch["key"])

        distribution = self.encoded_pixel_values_cache.get(cache_key, None)
        if distribution is None:
            distribution = self.vae.encode(batch["pixel_values"]).latent_dist
            self.encoded_pixel_values_cache[cache_key] = distribution

        # NOTE(review): 0.18215 is presumably the Stable Diffusion latent
        # scaling factor -- confirm against the model configuration.
        return distribution.sample().detach().half() * 0.18215
418 | |||
419 | |||
420 | def main(): | ||
421 | args = parse_args() | ||
422 | |||
423 | global_step_offset = 0 | ||
424 | if args.resume_from is not None: | ||
425 | basepath = f"{args.resume_from}" | ||
426 | print("Resuming state from %s" % args.resume_from) | ||
427 | with open(f"{basepath}/resume.json", 'r') as f: | ||
428 | state = json.load(f) | ||
429 | global_step_offset = state["args"].get("global_step", 0) | ||
430 | |||
431 | print("We've trained %d steps so far" % global_step_offset) | ||
432 | else: | ||
433 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
434 | basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}" | ||
435 | os.makedirs(basepath, exist_ok=True) | ||
436 | |||
437 | accelerator = Accelerator( | ||
438 | log_with=LoggerType.TENSORBOARD, | ||
439 | logging_dir=f"{basepath}", | ||
440 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
441 | mixed_precision=args.mixed_precision | ||
442 | ) | ||
443 | |||
444 | # If passed along, set the training seed now. | ||
445 | if args.seed is not None: | ||
446 | set_seed(args.seed) | ||
447 | |||
448 | # Load the tokenizer and add the placeholder token as an additional special token | |
449 | if args.tokenizer_name: | ||
450 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) | ||
451 | elif args.pretrained_model_name_or_path: | ||
452 | tokenizer = CLIPTokenizer.from_pretrained( | ||
453 | args.pretrained_model_name_or_path + '/tokenizer' | ||
454 | ) | ||
455 | |||
456 | # Add the placeholder token in tokenizer | ||
457 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | ||
458 | if num_added_tokens == 0: | ||
459 | raise ValueError( | ||
460 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" | ||
461 | " `placeholder_token` that is not already in the tokenizer." | ||
462 | ) | ||
463 | |||
464 | # Convert the initializer_token, placeholder_token to ids | ||
465 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
466 | # Check if initializer_token is a single token or a sequence of tokens | ||
467 | if args.vectors_per_token % len(initializer_token_ids) != 0: | ||
468 | raise ValueError( | ||
469 | f"vectors_per_token ({args.vectors_per_token}) must be divisible by initializer token ({len(initializer_token_ids)}).") | ||
470 | |||
471 | initializer_token_ids = torch.tensor(initializer_token_ids) | ||
472 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | ||
473 | |||
474 | # Load models and create wrapper for stable diffusion | ||
475 | text_encoder = CLIPTextModel.from_pretrained( | ||
476 | args.pretrained_model_name_or_path + '/text_encoder', | ||
477 | ) | ||
478 | vae = AutoencoderKL.from_pretrained( | ||
479 | args.pretrained_model_name_or_path + '/vae', | ||
480 | ) | ||
481 | unet = UNet2DConditionModel.from_pretrained( | ||
482 | args.pretrained_model_name_or_path + '/unet', | ||
483 | ) | ||
484 | |||
485 | slice_size = unet.config.attention_head_dim // 2 | ||
486 | unet.set_attention_slice(slice_size) | ||
487 | |||
488 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | ||
489 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
490 | |||
491 | # Initialise the newly added placeholder token with the embeddings of the initializer token | ||
492 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
493 | |||
494 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
495 | |||
496 | if args.resume_checkpoint is not None: | ||
497 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ | ||
498 | args.placeholder_token] | ||
499 | else: | ||
500 | token_embeds[placeholder_token_id] = initializer_token_embeddings | ||
501 | |||
502 | # Freeze vae and unet | ||
503 | freeze_params(vae.parameters()) | ||
504 | freeze_params(unet.parameters()) | ||
505 | # Freeze all parameters except for the token embeddings in text encoder | ||
506 | params_to_freeze = itertools.chain( | ||
507 | text_encoder.text_model.encoder.parameters(), | ||
508 | text_encoder.text_model.final_layer_norm.parameters(), | ||
509 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
510 | ) | ||
511 | freeze_params(params_to_freeze) | ||
512 | |||
513 | if args.scale_lr: | ||
514 | args.learning_rate = ( | ||
515 | args.learning_rate * args.gradient_accumulation_steps * | ||
516 | args.train_batch_size * accelerator.num_processes | ||
517 | ) | ||
518 | |||
519 | # Initialize the optimizer | ||
520 | optimizer = torch.optim.AdamW( | ||
521 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings | ||
522 | lr=args.learning_rate, | ||
523 | betas=(args.adam_beta1, args.adam_beta2), | ||
524 | weight_decay=args.adam_weight_decay, | ||
525 | eps=args.adam_epsilon, | ||
526 | ) | ||
527 | |||
528 | # TODO (patil-suraj): load scheduler using args | ||
529 | noise_scheduler = DDPMScheduler( | ||
530 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" | ||
531 | ) | ||
532 | |||
533 | datamodule = CSVDataModule( | ||
534 | data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer, | ||
535 | size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats, | ||
536 | center_crop=args.center_crop) | ||
537 | |||
538 | datamodule.prepare_data() | ||
539 | datamodule.setup() | ||
540 | |||
541 | train_dataloader = datamodule.train_dataloader() | ||
542 | val_dataloader = datamodule.val_dataloader() | ||
543 | |||
544 | checkpointer = Checkpointer( | ||
545 | datamodule=datamodule, | ||
546 | accelerator=accelerator, | ||
547 | vae=vae, | ||
548 | unet=unet, | ||
549 | tokenizer=tokenizer, | ||
550 | placeholder_token=args.placeholder_token, | ||
551 | placeholder_token_id=placeholder_token_id, | ||
552 | output_dir=basepath, | ||
553 | sample_image_size=args.sample_image_size, | ||
554 | sample_batch_size=args.sample_batch_size, | ||
555 | random_sample_batches=args.random_sample_batches, | ||
556 | stable_sample_batches=args.stable_sample_batches, | ||
557 | seed=args.seed | ||
558 | ) | ||
559 | |||
560 | # Scheduler and math around the number of training steps. | ||
561 | overrode_max_train_steps = False | ||
562 | num_update_steps_per_epoch = math.ceil( | ||
563 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | ||
564 | if args.max_train_steps is None: | ||
565 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
566 | overrode_max_train_steps = True | ||
567 | |||
568 | lr_scheduler = get_scheduler( | ||
569 | args.lr_scheduler, | ||
570 | optimizer=optimizer, | ||
571 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | ||
572 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
573 | ) | ||
574 | |||
575 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | ||
576 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
577 | ) | ||
578 | |||
579 | # Move vae and unet to device | ||
580 | vae.to(accelerator.device) | ||
581 | unet.to(accelerator.device) | ||
582 | |||
583 | # Keep vae and unet in eval mode as we don't train these | ||
584 | vae.eval() | ||
585 | unet.eval() | ||
586 | |||
587 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | ||
588 | num_update_steps_per_epoch = math.ceil( | ||
589 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | ||
590 | if overrode_max_train_steps: | ||
591 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
592 | # Afterwards we recalculate our number of training epochs | ||
593 | args.num_train_epochs = math.ceil( | ||
594 | args.max_train_steps / num_update_steps_per_epoch) | ||
595 | |||
596 | # We need to initialize the trackers we use, and also store our configuration. | ||
597 | # The trackers are initialized automatically on the main process. | ||
598 | if accelerator.is_main_process: | ||
599 | accelerator.init_trackers("textual_inversion", config=vars(args)) | ||
600 | |||
601 | # Train! | ||
602 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | ||
603 | |||
604 | logger.info("***** Running training *****") | ||
605 | logger.info(f" Num Epochs = {args.num_train_epochs}") | ||
606 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") | ||
607 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") | ||
608 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | ||
609 | logger.info(f" Total optimization steps = {args.max_train_steps}") | ||
610 | # Only show the progress bar once on each machine. | ||
611 | |||
612 | global_step = 0 | ||
613 | min_val_loss = np.inf | ||
614 | |||
615 | imageToLatents = ImageToLatents(vae) | ||
616 | |||
617 | checkpointer.save_samples( | ||
618 | "validation", | ||
619 | 0, | ||
620 | text_encoder, | ||
621 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
622 | |||
623 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | ||
624 | progress_bar.set_description("Global steps") | ||
625 | |||
626 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) | ||
627 | local_progress_bar.set_description("Steps") | ||
628 | |||
629 | try: | ||
630 | for epoch in range(args.num_train_epochs): | ||
631 | local_progress_bar.reset() | ||
632 | |||
633 | text_encoder.train() | ||
634 | train_loss = 0.0 | ||
635 | |||
636 | for step, batch in enumerate(train_dataloader): | ||
637 | with accelerator.accumulate(text_encoder): | ||
638 | with accelerator.autocast(): | ||
639 | # Convert images to latent space | ||
640 | latents = imageToLatents(batch) | ||
641 | |||
642 | # Sample noise that we'll add to the latents | ||
643 | noise = torch.randn(latents.shape).to(latents.device) | ||
644 | bsz = latents.shape[0] | ||
645 | # Sample a random timestep for each image | ||
646 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, | ||
647 | (bsz,), device=latents.device).long() | ||
648 | |||
649 | # Add noise to the latents according to the noise magnitude at each timestep | ||
650 | # (this is the forward diffusion process) | ||
651 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
652 | |||
653 | # Get the text embedding for conditioning | ||
654 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | ||
655 | |||
656 | # Predict the noise residual | ||
657 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
658 | |||
659 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
660 | |||
661 | accelerator.backward(loss) | ||
662 | |||
663 | # Zero out the gradients for all token embeddings except the newly added | ||
664 | # embeddings for the concept, as we only want to optimize the concept embeddings | ||
665 | if accelerator.num_processes > 1: | ||
666 | grads = text_encoder.module.get_input_embeddings().weight.grad | ||
667 | else: | ||
668 | grads = text_encoder.get_input_embeddings().weight.grad | ||
669 | # Get the index for tokens that we want to zero the grads for | ||
670 | index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id | ||
671 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) | ||
672 | |||
673 | optimizer.step() | ||
674 | if not accelerator.optimizer_step_was_skipped: | ||
675 | lr_scheduler.step() | ||
676 | optimizer.zero_grad() | ||
677 | |||
678 | loss = loss.detach().item() | ||
679 | train_loss += loss | ||
680 | |||
681 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
682 | if accelerator.sync_gradients: | ||
683 | progress_bar.update(1) | ||
684 | local_progress_bar.update(1) | ||
685 | global_step += 1 | ||
686 | |||
687 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | ||
688 | checkpointer.checkpoint(global_step + global_step_offset, text_encoder) | ||
689 | save_resume_file(basepath, args, { | ||
690 | "global_step": global_step + global_step_offset, | ||
691 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
692 | }) | ||
693 | checkpointer.save_samples( | ||
694 | "training", | ||
695 | global_step + global_step_offset, | ||
696 | text_encoder, | ||
697 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
698 | |||
699 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | ||
700 | local_progress_bar.set_postfix(**logs) | ||
701 | |||
702 | if global_step >= args.max_train_steps: | ||
703 | break | ||
704 | |||
705 | train_loss /= len(train_dataloader) | ||
706 | |||
707 | text_encoder.eval() | ||
708 | val_loss = 0.0 | ||
709 | |||
710 | for step, batch in enumerate(val_dataloader): | ||
711 | with torch.no_grad(), accelerator.autocast(): | ||
712 | latents = imageToLatents(batch) | ||
713 | |||
714 | noise = torch.randn(latents.shape).to(latents.device) | ||
715 | bsz = latents.shape[0] | ||
716 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, | ||
717 | (bsz,), device=latents.device).long() | ||
718 | |||
719 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
720 | |||
721 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | ||
722 | |||
723 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
724 | |||
725 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | ||
726 | |||
727 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
728 | |||
729 | loss = loss.detach().item() | ||
730 | val_loss += loss | ||
731 | |||
732 | if accelerator.sync_gradients: | ||
733 | progress_bar.update(1) | ||
734 | local_progress_bar.update(1) | ||
735 | |||
736 | logs = {"mode": "validation", "loss": loss} | ||
737 | local_progress_bar.set_postfix(**logs) | ||
738 | |||
739 | val_loss /= len(val_dataloader) | ||
740 | |||
741 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | ||
742 | |||
743 | if min_val_loss > val_loss: | ||
744 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | ||
745 | min_val_loss = val_loss | ||
746 | |||
747 | checkpointer.save_samples( | ||
748 | "validation", | ||
749 | global_step + global_step_offset, | ||
750 | text_encoder, | ||
751 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
752 | |||
753 | accelerator.wait_for_everyone() | ||
754 | |||
755 | # Create the pipeline using the trained modules and save it. | ||
756 | if accelerator.is_main_process: | ||
757 | print("Finished! Saving final checkpoint and resume state.") | ||
758 | checkpointer.checkpoint( | ||
759 | global_step + global_step_offset, | ||
760 | text_encoder, | ||
761 | path=f"{basepath}/learned_embeds.bin" | ||
762 | ) | ||
763 | |||
764 | save_resume_file(basepath, args, { | ||
765 | "global_step": global_step + global_step_offset, | ||
766 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
767 | }) | ||
768 | |||
769 | accelerator.end_training() | ||
770 | |||
771 | except KeyboardInterrupt: | ||
772 | if accelerator.is_main_process: | ||
773 | print("Interrupted, saving checkpoint and resume state...") | ||
774 | checkpointer.checkpoint(global_step + global_step_offset, text_encoder) | ||
775 | save_resume_file(basepath, args, { | ||
776 | "global_step": global_step + global_step_offset, | ||
777 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
778 | }) | ||
779 | accelerator.end_training() | ||
780 | quit() | ||
781 | |||
782 | |||
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
diff --git a/pipelines/stable_diffusion/no_check.py b/pipelines/stable_diffusion/no_check.py new file mode 100644 index 0000000..06c2f72 --- /dev/null +++ b/pipelines/stable_diffusion/no_check.py | |||
@@ -0,0 +1,13 @@ | |||
1 | from diffusers import ModelMixin | ||
2 | import torch | ||
3 | |||
4 | |||
class NoCheck(ModelMixin):
    """Can be used in place of safety checker. Use responsibly and at your own risk.

    The module never alters or flags anything: `forward` hands the images back
    untouched together with a list of `False` flags.
    """

    def __init__(self):
        super().__init__()
        # Dummy parameter kept under its original name 'asdf'.
        # NOTE(review): presumably present so the module owns at least one
        # parameter (e.g. for device/dtype lookups) - confirm.
        self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))

    def forward(self, images=None, **kwargs):
        """Return `images` unchanged plus one `False` flag per image.

        Args:
            images: Batch of images; anything supporting `len()`. May be `None`.
            **kwargs: Accepted and ignored.

        Returns:
            A tuple `(images, flags)` where `flags` is a list of `False` with
            one entry per image. When `images` has no length (e.g. `None`),
            a single-element list is returned, as before.
        """
        try:
            num_images = len(images)
        except TypeError:
            # No usable length: keep the previous single-flag result.
            num_images = 1
        return images, [False] * num_images
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py new file mode 100644 index 0000000..ee7fc33 --- /dev/null +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py | |||
@@ -0,0 +1,690 @@ | |||
1 | # coding=utf-8 | ||
2 | # Copyright 2022 The HuggingFace Inc. team. | ||
3 | # | ||
4 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
5 | # you may not use this file except in compliance with the License. | ||
6 | # You may obtain a copy of the License at | ||
7 | # | ||
8 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
9 | # | ||
10 | # Unless required by applicable law or agreed to in writing, software | ||
11 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
13 | # See the License for the specific language governing permissions and | ||
14 | # limitations under the License. | ||
15 | """ Conversion script for the LDM checkpoints. """ | ||
16 | |||
17 | import argparse | ||
18 | import os | ||
19 | |||
20 | import torch | ||
21 | |||
22 | |||
23 | try: | ||
24 | from omegaconf import OmegaConf | ||
25 | except ImportError: | ||
26 | raise ImportError( | ||
27 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." | ||
28 | ) | ||
29 | |||
30 | from diffusers import ( | ||
31 | AutoencoderKL, | ||
32 | DDIMScheduler, | ||
33 | LDMTextToImagePipeline, | ||
34 | LMSDiscreteScheduler, | ||
35 | PNDMScheduler, | ||
36 | StableDiffusionPipeline, | ||
37 | UNet2DConditionModel, | ||
38 | ) | ||
39 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel | ||
40 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
41 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer | ||
42 | |||
43 | |||
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drops dot-separated segments from `path`. A non-negative count removes that many leading
    segments; a negative count removes that many trailing segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments < 0:
        kept = segments[:n_shave_prefix_segments]
    else:
        kept = segments[n_shave_prefix_segments:]
    return ".".join(kept)
52 | |||
53 | |||
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Maps resnet keys to the new naming scheme (local renaming).

    Returns a list with one {"old": ..., "new": ...} dict per entry of `old_list`.
    """
    # Applied in this order to every key, before the prefix segments are shaved.
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    def _renamed(key):
        for old_fragment, new_fragment in renames:
            key = key.replace(old_fragment, new_fragment)
        return shave_segments(key, n_shave_prefix_segments=n_shave_prefix_segments)

    return [{"old": key, "new": _renamed(key)} for key in old_list]
74 | |||
75 | |||
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Maps VAE resnet keys to the new naming scheme (local renaming).

    "nin_shortcut" becomes "conv_shortcut", then the requested prefix segments are shaved.
    Returns a list with one {"old": ..., "new": ...} dict per entry of `old_list`.
    """
    return [
        {
            "old": key,
            "new": shave_segments(
                key.replace("nin_shortcut", "conv_shortcut"),
                n_shave_prefix_segments=n_shave_prefix_segments,
            ),
        }
        for key in old_list
    ]
90 | |||
91 | |||
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Builds the old -> new mapping for attention keys (local renaming).

    No local renaming is applied here: every key maps to itself, and
    `n_shave_prefix_segments` is accepted but not used.
    """
    return [{"old": key, "new": key} for key in old_list]
111 | |||
112 | |||
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Maps VAE attention keys to the new naming scheme (local renaming).

    Returns a list with one {"old": ..., "new": ...} dict per entry of `old_list`.
    """
    # (old fragment, new fragment) pairs, applied in this order to every key.
    renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "query.weight"),
        ("q.bias", "query.bias"),
        ("k.weight", "key.weight"),
        ("k.bias", "key.bias"),
        ("v.weight", "value.weight"),
        ("v.bias", "value.bias"),
        ("proj_out.weight", "proj_attn.weight"),
        ("proj_out.bias", "proj_attn.bias"),
    )

    mapping = []
    for old_item in old_list:
        renamed = old_item
        for old_fragment, new_fragment in renames:
            renamed = renamed.replace(old_fragment, new_fragment)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})
    return mapping
141 | |||
142 | |||
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    Final conversion step: writes weights from `old_checkpoint` into `checkpoint` under their
    globally renamed keys. `checkpoint` is modified in place; nothing is returned.

    Fused attention tensors named in `attention_paths_to_split` are split into query, key and
    value entries first (this reads `config["num_head_channels"]`). The remaining `paths` get the
    middle-block renames plus any `additional_replacements` before being assigned.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    split_requested = attention_paths_to_split is not None

    # Break every fused attention tensor into its query / key / value parts.
    if split_requested:
        for fused_path, targets in attention_paths_to_split.items():
            fused = old_checkpoint[fused_path]
            rows = fused.shape[0]
            channels = rows // 3

            if len(fused.shape) == 3:
                target_shape = (-1, channels)
            else:
                target_shape = -1

            num_heads = rows // config["num_head_channels"] // 3

            per_head = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
            query, key, value = per_head.split(channels // num_heads, dim=1)

            for part_name, part in (("query", query), ("key", key), ("value", value)):
                checkpoint[targets[part_name]] = part.reshape(target_shape)

    middle_block_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )

    for entry in paths:
        target = entry["new"]

        # Entries handled by the split above are already in place.
        if split_requested and target in attention_paths_to_split:
            continue

        # Global renaming happens here
        for old_name, new_name in middle_block_renames:
            target = target.replace(old_name, new_name)

        if additional_replacements is not None:
            for replacement in additional_replacements:
                target = target.replace(replacement["old"], replacement["new"])

        source = old_checkpoint[entry["old"]]
        # proj_attn.weight has to be converted from conv 1D to linear
        checkpoint[target] = source[:, :, 0] if "proj_attn.weight" in target else source
193 | |||
194 | |||
def conv_attn_to_linear(checkpoint):
    """
    Rewrites attention weights in `checkpoint` in place, dropping their trailing axes.

    Entries whose last two key segments are query.weight, key.weight or value.weight are
    replaced by `tensor[:, :, 0, 0]`; other entries whose key contains "proj_attn.weight" are
    replaced by `tensor[:, :, 0]`. Entries with two or fewer dimensions are left untouched.
    """
    attn_suffixes = ("query.weight", "key.weight", "value.weight")
    keep_all = slice(None)

    for name in list(checkpoint):
        if ".".join(name.split(".")[-2:]) in attn_suffixes:
            index = (keep_all, keep_all, 0, 0)
        elif "proj_attn.weight" in name:
            index = (keep_all, keep_all, 0)
        else:
            continue

        tensor = checkpoint[name]
        if tensor.ndim > 2:
            checkpoint[name] = tensor[index]
205 | |||
206 | |||
def create_unet_diffusers_config(original_config):
    """
    Derives the diffusers UNet config from the `unet_config` section of an LDM model config.

    Returns a plain dict holding the sample size, channel counts, block types, layers per block,
    cross-attention dimension and attention head dimension.
    """
    unet_params = original_config.model.params.unet_config.params

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
    last_index = len(block_out_channels) - 1

    # Going down: the resolution factor doubles after every block except the last one.
    resolution = 1
    down_block_types = []
    for index, _ in enumerate(block_out_channels):
        if resolution in unet_params.attention_resolutions:
            down_block_types.append("CrossAttnDownBlock2D")
        else:
            down_block_types.append("DownBlock2D")
        if index != last_index:
            resolution *= 2

    # Going up: start from the factor reached above and halve it after every block.
    up_block_types = []
    for _ in block_out_channels:
        if resolution in unet_params.attention_resolutions:
            up_block_types.append("CrossAttnUpBlock2D")
        else:
            up_block_types.append("UpBlock2D")
        resolution //= 2

    return {
        "sample_size": unet_params.image_size,
        "in_channels": unet_params.in_channels,
        "out_channels": unet_params.out_channels,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": unet_params.num_heads,
    }
242 | |||
243 | |||
def create_vae_diffusers_config(original_config):
    """
    Creates a config for the diffusers based on the config of the LDM model.

    Only `original_config.model.params.first_stage_config.params.ddconfig` is read. The earlier
    unused read of `embed_dim` was dropped, so a config without that field no longer fails here.

    Returns a plain dict with the sample size, channel counts, encoder/decoder block types,
    latent channels and layers per block.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    # One encoder block and one decoder block per channel multiplier.
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=vae_params.resolution,
        in_channels=vae_params.in_channels,
        out_channels=vae_params.out_ch,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=vae_params.z_channels,
        layers_per_block=vae_params.num_res_blocks,
    )
    return config
266 | |||
267 | |||
def create_diffusers_schedular(original_config):
    """
    Builds a DDIMScheduler from the `timesteps`, `linear_start` and `linear_end` values of the
    LDM model config, using the "scaled_linear" beta schedule.
    """
    return DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
276 | |||
277 | |||
def create_ldm_bert_config(original_config):
    """
    Creates the LDMBertConfig for the text encoder from the LDM model config.

    Reads `n_embed` and `n_layer` from `original_config.model.params.cond_stage_config.params`;
    the encoder feed-forward dimension is set to four times `n_embed`.
    """
    # Fixed attribute typo: this used to read `model.parms`, while the other config helpers in
    # this file read `model.params`.
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config
286 | |||
287 | |||
def convert_ldm_unet_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Entries of `checkpoint` whose key starts with "model.diffusion_model." are popped from it
    (so `checkpoint` is mutated) and re-keyed into a new dict under `time_embedding`, `conv_in`,
    `down_blocks`, `mid_block`, `up_blocks`, `conv_norm_out` and `conv_out` names.
    `config["layers_per_block"]` is read here, and `config` is passed on to `assign_to_checkpoint`.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Fixed-name layers are copied one to one.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    # The block count is the number of distinct first-two-segment prefixes ("input_blocks.<n>").
    # NOTE(review): the per-block filter below is a substring test, so e.g. "input_blocks.1" also
    # matches keys of "input_blocks.10"; the loops further down filter again with longer patterns.
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Input block 0 was already copied to conv_in above, so start at 1.
    for i in range(1, num_input_blocks):
        # Input blocks are grouped in runs of layers_per_block + 1 per down block.
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-layer 0 holds the resnet keys (the ".op" keys are handled separately below),
        # sub-layer 1 the attention keys.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # ".op" weights become the down block's downsampler conv; they are popped from the source dict.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Middle block: index 0 and 2 are treated as resnets, index 1 as attention.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    # No extra replacements needed for the resnets: assign_to_checkpoint already renames
    # "middle_block.0" / "middle_block.2" itself.
    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        # Output blocks are grouped in runs of layers_per_block + 1 per up block.
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        # Keys with the leading "output_blocks.<i>" segments removed.
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group the remaining key tails by their sub-layer id (first segment).
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            # NOTE(review): resnet_0_paths is not read again in this branch; `paths` holds the same mapping.
            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-layer consisting of exactly conv.weight / conv.bias is copied as the up block's upsampler.
            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # At most one sub-layer id was found: all keys are treated as resnet keys and copied directly.
            # path["old"] still starts with the sub-layer id, path["new"] has it shaved off.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
439 | |||
440 | |||
def convert_ldm_vae_checkpoint(checkpoint, config):
    """Rename the VAE entries of an LDM checkpoint to the diffusers key layout.

    Every entry of ``checkpoint`` whose key starts with ``first_stage_model.``
    is collected (with that prefix removed) and then copied into a fresh dict
    under its new name.

    Args:
        checkpoint: Mapping from original parameter names to their values.
        config: Passed through unchanged to ``assign_to_checkpoint``.

    Returns:
        A new dict holding the renamed VAE parameters.

    Raises:
        KeyError: If one of the directly copied entries (conv_in, conv_out,
            norm_out, quant_conv, post_quant_conv) is missing.
    """
    prefix = "first_stage_model."
    vae_state_dict = {
        name.replace(prefix, ""): checkpoint.get(name)
        for name in list(checkpoint.keys())
        if name.startswith(prefix)
    }

    new_checkpoint = {}
    param_kinds = ("weight", "bias")

    # Entries that only need a rename (norm_out -> conv_norm_out) or no change at all.
    renamed_layers = (("conv_in", "conv_in"), ("conv_out", "conv_out"), ("conv_norm_out", "norm_out"))
    for side in ("encoder", "decoder"):
        for new_layer, old_layer in renamed_layers:
            for kind in param_kinds:
                new_checkpoint[f"{side}.{new_layer}.{kind}"] = vae_state_dict[f"{side}.{old_layer}.{kind}"]

    for layer in ("quant_conv", "post_quant_conv"):
        for kind in param_kinds:
            new_checkpoint[f"{layer}.{kind}"] = vae_state_dict[f"{layer}.{kind}"]

    def _count_blocks(marker):
        # Distinct three-segment prefixes (e.g. "encoder.down.0") among keys containing `marker`.
        return len({".".join(name.split(".")[:3]) for name in vae_state_dict if marker in name})

    # Keys of the encoder down blocks, grouped by block index.
    num_down_blocks = _count_blocks("encoder.down")
    down_blocks = {
        idx: [name for name in vae_state_dict if f"down.{idx}" in name] for idx in range(num_down_blocks)
    }

    # Keys of the decoder up blocks, grouped by block index.
    num_up_blocks = _count_blocks("decoder.up")
    up_blocks = {idx: [name for name in vae_state_dict if f"up.{idx}" in name] for idx in range(num_up_blocks)}

    def _convert_mid_block(side):
        # Shared by encoder and decoder: two mid resnets, then the single mid attention.
        mid_resnets = [name for name in vae_state_dict if f"{side}.mid.block" in name]
        num_mid_res_blocks = 2
        for i in range(1, num_mid_res_blocks + 1):
            resnets = [name for name in mid_resnets if f"{side}.mid.block_{i}" in name]
            assign_to_checkpoint(
                renew_vae_resnet_paths(resnets),
                new_checkpoint,
                vae_state_dict,
                additional_replacements=[{"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}],
                config=config,
            )

        mid_attentions = [name for name in vae_state_dict if f"{side}.mid.attn" in name]
        assign_to_checkpoint(
            renew_vae_attention_paths(mid_attentions),
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[{"old": "mid.attn_1", "new": "mid_block.attentions.0"}],
            config=config,
        )
        conv_attn_to_linear(new_checkpoint)

    # Encoder: down blocks keep their index.
    for i in range(num_down_blocks):
        downsample_marker = f"down.{i}.downsample"
        resnets = [name for name in down_blocks[i] if downsample_marker not in name]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            # The downsampler entries are removed from the source dict as they are moved.
            for kind in param_kinds:
                new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.{kind}"] = vae_state_dict.pop(
                    f"encoder.down.{i}.downsample.conv.{kind}"
                )

        assign_to_checkpoint(
            renew_vae_resnet_paths(resnets),
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[{"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}],
            config=config,
        )

    _convert_mid_block("encoder")

    # Decoder: the block order is reversed, so new index i takes old index num_up_blocks - 1 - i.
    for i, block_id in enumerate(reversed(range(num_up_blocks))):
        upsample_marker = f"up.{block_id}.upsample"
        resnets = [name for name in up_blocks[block_id] if upsample_marker not in name]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            for kind in param_kinds:
                new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.{kind}"] = vae_state_dict[
                    f"decoder.up.{block_id}.upsample.conv.{kind}"
                ]

        assign_to_checkpoint(
            renew_vae_resnet_paths(resnets),
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[{"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}],
            config=config,
        )

    _convert_mid_block("decoder")
    return new_checkpoint
546 | |||
547 | |||
def convert_ldm_bert_checkpoint(checkpoint, config):
    """Build an ``LDMBertModel`` and fill it with weights read from ``checkpoint``.

    ``checkpoint`` is accessed by attribute (``checkpoint.transformer...``), so
    it has to be an object exposing a ``transformer`` attribute rather than a
    plain mapping of names to values.

    Args:
        checkpoint: Source object providing ``transformer.token_emb``,
            ``transformer.pos_emb.emb``, ``transformer.norm``,
            ``transformer.attn_layers.layers`` and ``transformer.to_logits``.
        config: Configuration handed to ``LDMBertModel``.

    Returns:
        The populated ``LDMBertModel``, already switched to eval mode.
    """

    def _transfer_weight_and_bias(target, source):
        target.weight = source.weight
        target.bias = source.bias

    def _transfer_attention(target_attn, source_attn):
        # The q/k/v projections only receive a weight, written through `.data`.
        target_attn.q_proj.weight.data = source_attn.to_q.weight
        target_attn.k_proj.weight.data = source_attn.to_k.weight
        target_attn.v_proj.weight.data = source_attn.to_v.weight

        _transfer_weight_and_bias(target_attn.out_proj, source_attn.to_out)

    def _transfer_block(target_block, source_pair):
        # source_pair[0] carries (norm, attention); source_pair[1] carries (norm, feed-forward).
        _transfer_weight_and_bias(target_block.self_attn_layer_norm, source_pair[0][0])
        _transfer_weight_and_bias(target_block.final_layer_norm, source_pair[1][0])

        _transfer_attention(target_block.self_attn, source_pair[0][1])

        feed_forward = source_pair[1][1]
        _transfer_weight_and_bias(target_block.fc1, feed_forward.net[0][0])
        _transfer_weight_and_bias(target_block.fc2, feed_forward.net[2])

    def _transfer_all_blocks(target_blocks, source_layers):
        # Each target block consumes two consecutive source entries.
        for index, target_block in enumerate(target_blocks):
            start = 2 * index
            _transfer_block(target_block, source_layers[start : start + 2])

    hf_model = LDMBertModel(config).eval()
    transformer = checkpoint.transformer

    # Token and position embeddings.
    hf_model.model.embed_tokens.weight = transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = transformer.pos_emb.emb.weight

    # Final layer norm.
    _transfer_weight_and_bias(hf_model.model.layer_norm, transformer.norm)

    # Hidden layers.
    _transfer_all_blocks(hf_model.model.layers, transformer.attn_layers.layers)

    # Output projection.
    _transfer_weight_and_bias(hf_model.to_logits, transformer.to_logits)

    return hf_model
596 | |||
597 | |||
if __name__ == "__main__":
    # Command-line entry point: read an original checkpoint plus its YAML config,
    # convert the UNet, VAE and text encoder, assemble a pipeline and save it to
    # --dump_path.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="pndm",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    if args.original_config_file is None:
        # No config given: download the reference v1 inference config into the
        # current directory. Fail right away when the download does not succeed
        # instead of surfacing later as an unrelated config-loading error.
        exit_status = os.system(
            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
        )
        if exit_status != 0:
            raise RuntimeError(
                "Could not download v1-inference.yaml with wget; pass --original_config_file explicitly."
            )
        args.original_config_file = "./v1-inference.yaml"

    original_config = OmegaConf.load(args.original_config_file)
    # NOTE(security): torch.load unpickles the file, so only convert checkpoints
    # that come from a trusted source.
    # map_location="cpu" lets checkpoints saved from a GPU be converted on a
    # machine without one; the weights are copied into freshly built models below.
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")["state_dict"]

    num_train_timesteps = original_config.model.params.timesteps
    beta_start = original_config.model.params.linear_start
    beta_end = original_config.model.params.linear_end
    if args.scheduler_type == "pndm":
        scheduler = PNDMScheduler(
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            skip_prk_steps=True,
        )
    elif args.scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
    elif args.scheduler_type == "ddim":
        scheduler = DDIMScheduler(
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
    else:
        raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model. The class name at the end of the configured
    # cond_stage target decides which pipeline gets assembled.
    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
    if text_model_type == "FrozenCLIPEmbedder":
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
        pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
    else:
        text_config = create_ldm_bert_config(original_config)
        # NOTE(review): `checkpoint` is the "state_dict" entry loaded above, while
        # convert_ldm_bert_checkpoint reads its input by attribute access
        # (`checkpoint.transformer...`) - confirm this branch works as intended.
        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    pipe.save_pretrained(args.dump_path)
diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..75158a0 --- /dev/null +++ b/setup.py | |||
@@ -0,0 +1,13 @@ | |||
1 | from setuptools import setup, find_packages | ||
2 | |||
# Package metadata for the textual-inversion-diff distribution.
setup(
    name="textual-inversion-diff",
    version="0.0.1",
    description="",
    # Include every package that find_packages() discovers next to this file.
    packages=find_packages(),
    install_requires=["torch", "numpy", "tqdm"],
)