 .gitignore                                                 | 164
 .pep8                                                      |   2
 data.py                                                    | 145
 environment.yaml                                           |  36
 main.py                                                    | 784
 pipelines/stable_diffusion/no_check.py                     |  13
 scripts/convert_original_stable_diffusion_to_diffusers.py  | 690
 setup.py                                                   |  13
 8 files changed, 1847 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..00e7681
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,164 @@
1# Byte-compiled / optimized / DLL files
2__pycache__/
3*.py[cod]
4*$py.class
5
6# C extensions
7*.so
8
9# Distribution / packaging
10.Python
11build/
12develop-eggs/
13dist/
14downloads/
15eggs/
16.eggs/
17lib/
18lib64/
19parts/
20sdist/
21var/
22wheels/
23share/python-wheels/
24*.egg-info/
25.installed.cfg
26*.egg
27MANIFEST
28
29# PyInstaller
30# Usually these files are written by a python script from a template
31# before PyInstaller builds the exe, so as to inject date/other infos into it.
32*.manifest
33*.spec
34
35# Installer logs
36pip-log.txt
37pip-delete-this-directory.txt
38
39# Unit test / coverage reports
40htmlcov/
41.tox/
42.nox/
43.coverage
44.coverage.*
45.cache
46nosetests.xml
47coverage.xml
48*.cover
49*.py,cover
50.hypothesis/
51.pytest_cache/
52cover/
53
54# Translations
55*.mo
56*.pot
57
58# Django stuff:
59*.log
60local_settings.py
61db.sqlite3
62db.sqlite3-journal
63
64# Flask stuff:
65instance/
66.webassets-cache
67
68# Scrapy stuff:
69.scrapy
70
71# Sphinx documentation
72docs/_build/
73
74# PyBuilder
75.pybuilder/
76target/
77
78# Jupyter Notebook
79.ipynb_checkpoints
80
81# IPython
82profile_default/
83ipython_config.py
84
85# pyenv
86# For a library or package, you might want to ignore these files since the code is
87# intended to run in multiple environments; otherwise, check them in:
88# .python-version
89
90# pipenv
91# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92# However, in case of collaboration, if having platform-specific dependencies or dependencies
93# having no cross-platform support, pipenv may install dependencies that don't work, or not
94# install all needed dependencies.
95#Pipfile.lock
96
97# poetry
98# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99# This is especially recommended for binary packages to ensure reproducibility, and is more
100# commonly ignored for libraries.
101# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102#poetry.lock
103
104# pdm
105# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106#pdm.lock
107# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108# in version control.
109# https://pdm.fming.dev/#use-with-ide
110.pdm.toml
111
112# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113__pypackages__/
114
115# Celery stuff
116celerybeat-schedule
117celerybeat.pid
118
119# SageMath parsed files
120*.sage.py
121
122# Environments
123.env
124.venv
125env/
126venv/
127ENV/
128env.bak/
129venv.bak/
130
131# Spyder project settings
132.spyderproject
133.spyproject
134
135# Rope project settings
136.ropeproject
137
138# mkdocs documentation
139/site
140
141# mypy
142.mypy_cache/
143.dmypy.json
144dmypy.json
145
146# Pyre type checker
147.pyre/
148
149# pytype static type analyzer
150.pytype/
151
152# Cython debug symbols
153cython_debug/
154
155# PyCharm
156# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158# and can be added to the global gitignore or merged into this file. For a more nuclear
159# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160#.idea/
161
162text-inversion-model/
163conf.json
164v1-inference.yaml
diff --git a/.pep8 b/.pep8
new file mode 100644
index 0000000..9d54e0f
--- /dev/null
+++ b/.pep8
@@ -0,0 +1,2 @@
1[pycodestyle]
2max_line_length = 120
diff --git a/data.py b/data.py
new file mode 100644
index 0000000..0d1e96e
--- /dev/null
+++ b/data.py
@@ -0,0 +1,145 @@
1import os
2import numpy as np
3import pandas as pd
4import random
5import PIL
6import pytorch_lightning as pl
7from PIL import Image
8import torch
9from torch.utils.data import Dataset, DataLoader, random_split
10from torchvision import transforms
11
12
13class CSVDataModule(pl.LightningDataModule):
14 def __init__(self,
15 batch_size,
16 data_root,
17 tokenizer,
18 size=512,
19 repeats=100,
20 interpolation="bicubic",
21 placeholder_token="*",
22 flip_p=0.5,
23 center_crop=False):
24 super().__init__()
25
26 self.data_root = data_root
27 self.tokenizer = tokenizer
28 self.size = size
29 self.repeats = repeats
30 self.placeholder_token = placeholder_token
31 self.center_crop = center_crop
32 self.flip_p = flip_p
33 self.interpolation = interpolation
34
35 self.batch_size = batch_size
36
37 def prepare_data(self):
38 metadata = pd.read_csv(f'{self.data_root}/list.csv')
39 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
40 captions = [caption for caption in metadata['caption'].values]
41 skips = [skip for skip in metadata['skip'].values]
42 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
43
44 def setup(self, stage=None):
45 train_set_size = int(len(self.data_full) * 0.8)
46 valid_set_size = len(self.data_full) - train_set_size
47 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
48
49 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
50 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
51 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
52 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
53 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
54 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
55
56 def train_dataloader(self):
57 return self.train_dataloader_
58
59 def val_dataloader(self):
60 return self.val_dataloader_
61
62
63class CSVDataset(Dataset):
64 def __init__(self,
65 data,
66 tokenizer,
67 size=512,
68 repeats=1,
69 interpolation="bicubic",
70 flip_p=0.5,
71 placeholder_token="*",
72 center_crop=False,
73 ):
74
75 self.data = data
76 self.tokenizer = tokenizer
77
78 self.num_images = len(self.data)
79 self._length = self.num_images * repeats
80
81 self.placeholder_token = placeholder_token
82
83 self.size = size
84 self.center_crop = center_crop
85 self.interpolation = {"linear": PIL.Image.LINEAR,
86 "bilinear": PIL.Image.BILINEAR,
87 "bicubic": PIL.Image.BICUBIC,
88 "lanczos": PIL.Image.LANCZOS,
89 }[interpolation]
90 self.flip = transforms.RandomHorizontalFlip(p=flip_p)
91
92 self.cache = {}
93
94 def __len__(self):
95 return self._length
96
97 def get_example(self, i, flipped):
98 image_path, text = self.data[i % self.num_images]
99
100 if image_path in self.cache:
101 return self.cache[image_path]
102
103 example = {}
104 image = Image.open(image_path)
105
106 if not image.mode == "RGB":
107 image = image.convert("RGB")
108
109 text = text.format(self.placeholder_token)
110
111 example["prompt"] = text
112 example["input_ids"] = self.tokenizer(
113 text,
114 padding="max_length",
115 truncation=True,
116 max_length=self.tokenizer.model_max_length,
117 return_tensors="pt",
118 ).input_ids[0]
119
120 # default to score-sde preprocessing
121 img = np.array(image).astype(np.uint8)
122
123 if self.center_crop:
124 crop = min(img.shape[0], img.shape[1])
125 h, w, = img.shape[0], img.shape[1]
126 img = img[(h - crop) // 2:(h + crop) // 2,
127 (w - crop) // 2:(w + crop) // 2]
128
129 image = Image.fromarray(img)
130 image = image.resize((self.size, self.size),
131 resample=self.interpolation)
132 image = self.flip(image)
133 image = np.array(image).astype(np.uint8)
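        # Scale uint8 pixels from [0, 255] to float32 in [-1, 1], the input range the Stable Diffusion VAE expects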
134 image = (image / 127.5 - 1.0).astype(np.float32)
135
136 example["key"] = "-".join([image_path, "-", str(flipped)])
137 example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
138
139 self.cache[image_path] = example
140 return example
141
142 def __getitem__(self, i):
143 flipped = random.choice([False, True])
144 example = self.get_example(i, flipped)
145 return example
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000..a460158
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,36 @@
1name: ldd
2channels:
3 - pytorch
4 - defaults
5dependencies:
6 - cudatoolkit=11.3
7 - numpy=1.22.3
8 - pip=20.3
9 - python=3.8.10
10 - pytorch=1.12.1
11 - torchvision=0.13.1
12 - pandas=1.4.3
13 - pip:
14 - -e .
15 - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
16 - -e git+https://github.com/openai/CLIP.git@main#egg=clip
17 - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
18 - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
19 - accelerate==0.12.0
20 - albumentations==1.1.0
21 - diffusers==0.3.0
22 - einops==0.4.1
23 - imageio-ffmpeg==0.4.7
24 - imageio==2.14.1
25 - kornia==0.6
26 - pudb==2019.2
27 - omegaconf==2.1.1
28 - opencv-python-headless==4.6.0.66
29 - python-slugify>=6.1.2
30 - pytorch-lightning==1.7.7
31 - setuptools==59.5.0
32 - streamlit>=0.73.1
33 - test-tube>=0.7.5
34 - torch-fidelity==0.3.0
35 - torchmetrics==0.9.3
36 - transformers==4.19.2
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..9bf65a5
--- /dev/null
+++ b/main.py
@@ -0,0 +1,784 @@
1import argparse
2import itertools
3import math
4import os
5import random
6import datetime
7from pathlib import Path
8from typing import Optional
9
10import numpy as np
11import torch
12import torch.nn as nn
13import torch.nn.functional as F
14import torch.utils.checkpoint
15from torch.utils.data import Dataset
16
17import PIL
18from accelerate import Accelerator
19from accelerate.logging import get_logger
20from accelerate.utils import LoggerType, set_seed
21from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
22from diffusers.optimization import get_scheduler
23from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
24from einops import rearrange
25from pipelines.stable_diffusion.no_check import NoCheck
26from huggingface_hub import HfFolder, Repository, whoami
27from PIL import Image
28from tqdm.auto import tqdm
29from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
30from slugify import slugify
31import json
32import os
33import sys
34
35from data import CSVDataModule
36
37logger = get_logger(__name__)
38
39
40def parse_args():
41 parser = argparse.ArgumentParser(
42 description="Simple example of a training script.")
43 parser.add_argument(
44 "--pretrained_model_name_or_path",
45 type=str,
46 default=None,
47 help="Path to pretrained model or model identifier from huggingface.co/models.",
48 )
49 parser.add_argument(
50 "--tokenizer_name",
51 type=str,
52 default=None,
53 help="Pretrained tokenizer name or path if not the same as model_name",
54 )
55 parser.add_argument(
56 "--train_data_dir", type=str, default=None, help="A folder containing the training data."
57 )
58 parser.add_argument(
59 "--placeholder_token",
60 type=str,
61 default=None,
62 help="A token to use as a placeholder for the concept.",
63 )
64 parser.add_argument(
65 "--initializer_token", type=str, default=None, help="A token to use as initializer word."
66 )
67 parser.add_argument(
68 "--vectors_per_token", type=int, default=1, help="Vectors per token."
69 )
70 parser.add_argument("--repeats", type=int, default=100,
71 help="How many times to repeat the training data.")
72 parser.add_argument(
73 "--output_dir",
74 type=str,
75 default="text-inversion-model",
76 help="The output directory where the model predictions and checkpoints will be written.",
77 )
78 parser.add_argument("--seed", type=int, default=None,
79 help="A seed for reproducible training.")
80 parser.add_argument(
81 "--resolution",
82 type=int,
83 default=512,
84 help=(
85 "The resolution for input images; all the images in the train/validation dataset will be resized to this"
86 " resolution"
87 ),
88 )
89 parser.add_argument(
90 "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
91 )
92 parser.add_argument(
93 "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
94 )
95 parser.add_argument("--num_train_epochs", type=int, default=100)
96 parser.add_argument(
97 "--max_train_steps",
98 type=int,
99 default=5000,
100 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
101 )
102 parser.add_argument(
103 "--gradient_accumulation_steps",
104 type=int,
105 default=1,
106 help="Number of update steps to accumulate before performing a backward/update pass.",
107 )
108 parser.add_argument(
109 "--learning_rate",
110 type=float,
111 default=1e-4,
112 help="Initial learning rate (after the potential warmup period) to use.",
113 )
114 parser.add_argument(
115 "--scale_lr",
116 action="store_true",
117 default=True,
118 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
119 )
120 parser.add_argument(
121 "--lr_scheduler",
122 type=str,
123 default="constant",
124 help=(
125 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
126 ' "constant", "constant_with_warmup"]'
127 ),
128 )
129 parser.add_argument(
130 "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
131 )
132 parser.add_argument("--adam_beta1", type=float, default=0.9,
133 help="The beta1 parameter for the Adam optimizer.")
134 parser.add_argument("--adam_beta2", type=float, default=0.999,
135 help="The beta2 parameter for the Adam optimizer.")
136 parser.add_argument("--adam_weight_decay", type=float,
137 default=1e-2, help="Weight decay to use.")
138 parser.add_argument("--adam_epsilon", type=float, default=1e-08,
139 help="Epsilon value for the Adam optimizer")
140 parser.add_argument(
141 "--mixed_precision",
142 type=str,
143 default="no",
144 choices=["no", "fp16", "bf16"],
145 help=(
146 "Whether to use mixed precision. Choose "
147 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
148 "and an Nvidia Ampere GPU."
149 ),
150 )
151 parser.add_argument("--local_rank", type=int, default=-1,
152 help="For distributed training: local_rank")
153 parser.add_argument(
154 "--checkpoint_frequency",
155 type=int,
156 default=500,
157 help="How often to save a checkpoint and sample image",
158 )
159 parser.add_argument(
160 "--sample_image_size",
161 type=int,
162 default=512,
163 help="Size of sample images",
164 )
165 parser.add_argument(
166 "--stable_sample_batches",
167 type=int,
168 default=1,
169 help="Number of fixed seed sample batches to generate per checkpoint",
170 )
171 parser.add_argument(
172 "--random_sample_batches",
173 type=int,
174 default=1,
175 help="Number of random seed sample batches to generate per checkpoint",
176 )
177 parser.add_argument(
178 "--sample_batch_size",
179 type=int,
180 default=1,
181 help="Number of samples to generate per batch",
182 )
183 parser.add_argument(
184 "--sample_steps",
185 type=int,
186 default=50,
187 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
188 )
189 parser.add_argument(
190 "--resume_from",
191 type=str,
192 default=None,
193 help="Path to a directory to resume training from (e.g., logs/token_name/2022-09-22T23-36-27)"
194 )
195 parser.add_argument(
196 "--resume_checkpoint",
197 type=str,
198 default=None,
199 help="Path to a specific checkpoint to resume training from (e.g., logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
200 )
201 parser.add_argument(
202 "--config",
203 type=str,
204 default=None,
205 help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
206 )
207
208 args = parser.parse_args()
209 if args.resume_from is not None:
210 with open(f"{args.resume_from}/resume.json", 'rt') as f:
211 args = parser.parse_args(
212 namespace=argparse.Namespace(**json.load(f)["args"]))
213 elif args.config is not None:
214 with open(args.config, 'rt') as f:
215 args = parser.parse_args(
216 namespace=argparse.Namespace(**json.load(f)))
217
218 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
219 if env_local_rank != -1 and env_local_rank != args.local_rank:
220 args.local_rank = env_local_rank
221
222 if args.train_data_dir is None:
223 raise ValueError("You must specify --train_data_dir")
224
225 if args.pretrained_model_name_or_path is None:
226 raise ValueError("You must specify --pretrained_model_name_or_path")
227
228 if args.placeholder_token is None:
229 raise ValueError("You must specify --placeholder_token")
230
231 if args.initializer_token is None:
232 raise ValueError("You must specify --initializer_token")
233
234 if args.output_dir is None:
235 raise ValueError("You must specify --output_dir")
236
237 return args
238
239
240def freeze_params(params):
241 for param in params:
242 param.requires_grad = False
243
244
245def save_resume_file(basepath, args, extra={}):
246 info = {"args": vars(args)}
247 info["args"].update(extra)
248 with open(f"{basepath}/resume.json", "w") as f:
249 json.dump(info, f, indent=4)
250
251
252def make_grid(images, rows, cols):
253 w, h = images[0].size
254 grid = Image.new('RGB', size=(cols*w, rows*h))
255 for i, image in enumerate(images):
256 grid.paste(image, box=(i % cols*w, i//cols*h))
257 return grid
258
259
260class Checkpointer:
261 def __init__(
262 self,
263 datamodule,
264 accelerator,
265 vae,
266 unet,
267 tokenizer,
268 placeholder_token,
269 placeholder_token_id,
270 output_dir,
271 sample_image_size,
272 random_sample_batches,
273 sample_batch_size,
274 stable_sample_batches,
275 seed
276 ):
277 self.datamodule = datamodule
278 self.accelerator = accelerator
279 self.vae = vae
280 self.unet = unet
281 self.tokenizer = tokenizer
282 self.placeholder_token = placeholder_token
283 self.placeholder_token_id = placeholder_token_id
284 self.output_dir = output_dir
285 self.sample_image_size = sample_image_size
286 self.seed = seed
287 self.random_sample_batches = random_sample_batches
288 self.sample_batch_size = sample_batch_size
289 self.stable_sample_batches = stable_sample_batches
290
291 @torch.no_grad()
292 def checkpoint(self, step, text_encoder, save_samples=True, path=None):
293 print("Saving checkpoint for step %d..." % step)
294 with self.accelerator.autocast():
295 if path is None:
296 checkpoints_path = f"{self.output_dir}/checkpoints"
297 os.makedirs(checkpoints_path, exist_ok=True)
298
299 unwrapped = self.accelerator.unwrap_model(text_encoder)
300
301 # Save a checkpoint
302 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
303 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
304
305 filename = f"%s_%d.bin" % (slugify(self.placeholder_token), step)
306 if path is not None:
307 torch.save(learned_embeds_dict, path)
308 else:
309 torch.save(learned_embeds_dict,
310 f"{checkpoints_path}/{filename}")
311 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
312 del unwrapped
313 del learned_embeds
314
315 @torch.no_grad()
316 def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
317 samples_path = f"{self.output_dir}/samples/{mode}"
318 os.makedirs(samples_path, exist_ok=True)
319 checker = NoCheck()
320
321 unwrapped = self.accelerator.unwrap_model(text_encoder)
322 # Save a sample image
323 pipeline = StableDiffusionPipeline(
324 text_encoder=unwrapped,
325 vae=self.vae,
326 unet=self.unet,
327 tokenizer=self.tokenizer,
328 scheduler=LMSDiscreteScheduler(
329 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
330 ),
331 safety_checker=NoCheck(),
332 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
333 ).to(self.accelerator.device)
334 pipeline.enable_attention_slicing()
335
336 data = {
337 "training": self.datamodule.train_dataloader(),
338 "validation": self.datamodule.val_dataloader(),
339 }[mode]
340
341 if mode == "validation" and self.stable_sample_batches > 0:
342 stable_latents = torch.randn(
343 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
344 device=pipeline.device,
345 generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
346 )
347
348 all_samples = []
349 filename = f"stable_step_%d.png" % (step)
350
351 # Generate and save stable samples
352 for i in range(0, self.stable_sample_batches):
353 prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
354 samples = pipeline(
355 prompt=prompt,
356 height=self.sample_image_size,
357 latents=stable_latents,
358 width=self.sample_image_size,
359 guidance_scale=guidance_scale,
360 eta=eta,
361 num_inference_steps=num_inference_steps,
362 output_type='pil'
363 )["sample"]
364
365 all_samples += samples
366 del samples
367
368 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
369 image_grid.save(f"{samples_path}/{filename}")
370
371 del all_samples
372 del image_grid
373 del stable_latents
374
375 all_samples = []
376 filename = f"step_%d.png" % (step)
377
378 # Generate and save random samples
379 for i in range(0, self.random_sample_batches):
380 prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
381 samples = pipeline(
382 prompt=prompt,
383 height=self.sample_image_size,
384 width=self.sample_image_size,
385 guidance_scale=guidance_scale,
386 eta=eta,
387 num_inference_steps=num_inference_steps,
388 output_type='pil'
389 )["sample"]
390
391 all_samples += samples
392 del samples
393
394 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
395 image_grid.save(f"{samples_path}/{filename}")
396
397 del all_samples
398 del image_grid
399
400 del checker
401 del unwrapped
402 del pipeline
403 torch.cuda.empty_cache()
404
405
406class ImageToLatents():
407 def __init__(self, vae):
408 self.vae = vae
409 self.encoded_pixel_values_cache = {}
410
411 @torch.no_grad()
412 def __call__(self, batch):
413 key = "|".join(batch["key"])
414 if self.encoded_pixel_values_cache.get(key, None) is None:
415 self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist
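        # 0.18215 is the latent scaling factor from the original Stable Diffusion configuration;
        # it brings the VAE latents to roughly unit variance before they enter the diffusion process.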
416 latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
417 return latents
418
419
420def main():
421 args = parse_args()
422
423 global_step_offset = 0
424 if args.resume_from is not None:
425 basepath = f"{args.resume_from}"
426 print("Resuming state from %s" % args.resume_from)
427 with open(f"{basepath}/resume.json", 'r') as f:
428 state = json.load(f)
429 global_step_offset = state["args"].get("global_step", 0)
430
431 print("We've trained %d steps so far" % global_step_offset)
432 else:
433 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
434 basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}"
435 os.makedirs(basepath, exist_ok=True)
436
437 accelerator = Accelerator(
438 log_with=LoggerType.TENSORBOARD,
439 logging_dir=f"{basepath}",
440 gradient_accumulation_steps=args.gradient_accumulation_steps,
441 mixed_precision=args.mixed_precision
442 )
443
444 # If passed along, set the training seed now.
445 if args.seed is not None:
446 set_seed(args.seed)
447
448 # Load the tokenizer and add the placeholder token as an additional special token
449 if args.tokenizer_name:
450 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
451 elif args.pretrained_model_name_or_path:
452 tokenizer = CLIPTokenizer.from_pretrained(
453 args.pretrained_model_name_or_path + '/tokenizer'
454 )
455
456 # Add the placeholder token in tokenizer
457 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
458 if num_added_tokens == 0:
459 raise ValueError(
460 f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
461 " `placeholder_token` that is not already in the tokenizer."
462 )
463
464 # Convert the initializer_token, placeholder_token to ids
465 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
466 # Check if initializer_token is a single token or a sequence of tokens
467 if args.vectors_per_token % len(initializer_token_ids) != 0:
468 raise ValueError(
469 f"vectors_per_token ({args.vectors_per_token}) must be divisible by initializer token ({len(initializer_token_ids)}).")
470
471 initializer_token_ids = torch.tensor(initializer_token_ids)
472 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
473
474 # Load models and create wrapper for stable diffusion
475 text_encoder = CLIPTextModel.from_pretrained(
476 args.pretrained_model_name_or_path + '/text_encoder',
477 )
478 vae = AutoencoderKL.from_pretrained(
479 args.pretrained_model_name_or_path + '/vae',
480 )
481 unet = UNet2DConditionModel.from_pretrained(
482 args.pretrained_model_name_or_path + '/unet',
483 )
484
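    # Attention slicing computes attention in smaller chunks, trading a little speed for lower peak VRAM.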
485 slice_size = unet.config.attention_head_dim // 2
486 unet.set_attention_slice(slice_size)
487
488 # Resize the token embeddings as we are adding new special tokens to the tokenizer
489 text_encoder.resize_token_embeddings(len(tokenizer))
490
491 # Initialise the newly added placeholder token with the embeddings of the initializer token
492 token_embeds = text_encoder.get_input_embeddings().weight.data
493
494 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
495
496 if args.resume_checkpoint is not None:
497 token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
498 args.placeholder_token]
499 else:
500 token_embeds[placeholder_token_id] = initializer_token_embeddings
501
502 # Freeze vae and unet
503 freeze_params(vae.parameters())
504 freeze_params(unet.parameters())
505 # Freeze all parameters except for the token embeddings in text encoder
506 params_to_freeze = itertools.chain(
507 text_encoder.text_model.encoder.parameters(),
508 text_encoder.text_model.final_layer_norm.parameters(),
509 text_encoder.text_model.embeddings.position_embedding.parameters(),
510 )
511 freeze_params(params_to_freeze)
512
513 if args.scale_lr:
514 args.learning_rate = (
515 args.learning_rate * args.gradient_accumulation_steps *
516 args.train_batch_size * accelerator.num_processes
517 )
518
519 # Initialize the optimizer
520 optimizer = torch.optim.AdamW(
521 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
522 lr=args.learning_rate,
523 betas=(args.adam_beta1, args.adam_beta2),
524 weight_decay=args.adam_weight_decay,
525 eps=args.adam_epsilon,
526 )
527
528 # TODO (patil-suraj): load scheduler using args
529 noise_scheduler = DDPMScheduler(
530 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
531 )
532
533 datamodule = CSVDataModule(
534 data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer,
535 size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats,
536 center_crop=args.center_crop)
537
538 datamodule.prepare_data()
539 datamodule.setup()
540
541 train_dataloader = datamodule.train_dataloader()
542 val_dataloader = datamodule.val_dataloader()
543
544 checkpointer = Checkpointer(
545 datamodule=datamodule,
546 accelerator=accelerator,
547 vae=vae,
548 unet=unet,
549 tokenizer=tokenizer,
550 placeholder_token=args.placeholder_token,
551 placeholder_token_id=placeholder_token_id,
552 output_dir=basepath,
553 sample_image_size=args.sample_image_size,
554 sample_batch_size=args.sample_batch_size,
555 random_sample_batches=args.random_sample_batches,
556 stable_sample_batches=args.stable_sample_batches,
557 seed=args.seed
558 )
559
560 # Scheduler and math around the number of training steps.
561 overrode_max_train_steps = False
562 num_update_steps_per_epoch = math.ceil(
563 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
564 if args.max_train_steps is None:
565 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
566 overrode_max_train_steps = True
567
568 lr_scheduler = get_scheduler(
569 args.lr_scheduler,
570 optimizer=optimizer,
571 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
572 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
573 )
574
575 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
576 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
577 )
578
579 # Move vae and unet to device
580 vae.to(accelerator.device)
581 unet.to(accelerator.device)
582
583 # Keep vae and unet in eval mode as we don't train these
584 vae.eval()
585 unet.eval()
586
587 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
588 num_update_steps_per_epoch = math.ceil(
589 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
590 if overrode_max_train_steps:
591 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
592 # Afterwards we recalculate our number of training epochs
593 args.num_train_epochs = math.ceil(
594 args.max_train_steps / num_update_steps_per_epoch)
595
596 # We need to initialize the trackers we use, and also store our configuration.
597 # The trackers initialize automatically on the main process.
598 if accelerator.is_main_process:
599 accelerator.init_trackers("textual_inversion", config=vars(args))
600
601 # Train!
602 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
603
604 logger.info("***** Running training *****")
605 logger.info(f" Num Epochs = {args.num_train_epochs}")
606 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
607 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
608 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
609 logger.info(f" Total optimization steps = {args.max_train_steps}")
610 # Only show the progress bar once on each machine.
611
612 global_step = 0
613 min_val_loss = np.inf
614
615 imageToLatents = ImageToLatents(vae)
616
617 checkpointer.save_samples(
618 "validation",
619 0,
620 text_encoder,
621 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
622
623 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
624 progress_bar.set_description("Global steps")
625
626 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
627 local_progress_bar.set_description("Steps")
628
629 try:
630 for epoch in range(args.num_train_epochs):
631 local_progress_bar.reset()
632
633 text_encoder.train()
634 train_loss = 0.0
635
636 for step, batch in enumerate(train_dataloader):
637 with accelerator.accumulate(text_encoder):
638 with accelerator.autocast():
639 # Convert images to latent space
640 latents = imageToLatents(batch)
641
642 # Sample noise that we'll add to the latents
643 noise = torch.randn(latents.shape).to(latents.device)
644 bsz = latents.shape[0]
645 # Sample a random timestep for each image
646 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
647 (bsz,), device=latents.device).long()
648
649 # Add noise to the latents according to the noise magnitude at each timestep
650 # (this is the forward diffusion process)
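                        # i.e. noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise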
651 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
652
653 # Get the text embedding for conditioning
654 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
655
656 # Predict the noise residual
657 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
658
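                        # Standard epsilon-prediction objective: per-image MSE over (C, H, W), then averaged over the batch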
659 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
660
661 accelerator.backward(loss)
662
663 # Zero out the gradients for all token embeddings except the newly added
664 # embeddings for the concept, as we only want to optimize the concept embeddings
665 if accelerator.num_processes > 1:
666 grads = text_encoder.module.get_input_embeddings().weight.grad
667 else:
668 grads = text_encoder.get_input_embeddings().weight.grad
669 # Get the index for tokens that we want to zero the grads for
670 index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
671 grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
672
673 optimizer.step()
674 if not accelerator.optimizer_step_was_skipped:
675 lr_scheduler.step()
676 optimizer.zero_grad()
677
678 loss = loss.detach().item()
679 train_loss += loss
680
681 # Checks if the accelerator has performed an optimization step behind the scenes
682 if accelerator.sync_gradients:
683 progress_bar.update(1)
684 local_progress_bar.update(1)
685 global_step += 1
686
687 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
688 checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
689 save_resume_file(basepath, args, {
690 "global_step": global_step + global_step_offset,
691 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
692 })
693 checkpointer.save_samples(
694 "training",
695 global_step + global_step_offset,
696 text_encoder,
697 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
698
699 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
700 local_progress_bar.set_postfix(**logs)
701
702 if global_step >= args.max_train_steps:
703 break
704
705 train_loss /= len(train_dataloader)
706
707 text_encoder.eval()
708 val_loss = 0.0
709
710 for step, batch in enumerate(val_dataloader):
711 with torch.no_grad(), accelerator.autocast():
712 latents = imageToLatents(batch)
713
714 noise = torch.randn(latents.shape).to(latents.device)
715 bsz = latents.shape[0]
716 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
717 (bsz,), device=latents.device).long()
718
719 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
720
721 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
722
723 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
724
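                    # Gather predictions and targets from all processes so the validation loss reflects every device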
725 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
726
727 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
728
729 loss = loss.detach().item()
730 val_loss += loss
731
732 if accelerator.sync_gradients:
733 progress_bar.update(1)
734 local_progress_bar.update(1)
735
736 logs = {"mode": "validation", "loss": loss}
737 local_progress_bar.set_postfix(**logs)
738
739 val_loss /= len(val_dataloader)
740
741 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
742
743 if min_val_loss > val_loss:
744 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
745 min_val_loss = val_loss
746
747 checkpointer.save_samples(
748 "validation",
749 global_step + global_step_offset,
750 text_encoder,
751 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
752
753 accelerator.wait_for_everyone()
754
755 # Save the final learned embeddings and resume state.
756 if accelerator.is_main_process:
757 print("Finished! Saving final checkpoint and resume state.")
758 checkpointer.checkpoint(
759 global_step + global_step_offset,
760 text_encoder,
761 path=f"{basepath}/learned_embeds.bin"
762 )
763
764 save_resume_file(basepath, args, {
765 "global_step": global_step + global_step_offset,
766 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
767 })
768
769 accelerator.end_training()
770
771 except KeyboardInterrupt:
772 if accelerator.is_main_process:
773 print("Interrupted, saving checkpoint and resume state...")
774 checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
775 save_resume_file(basepath, args, {
776 "global_step": global_step + global_step_offset,
777 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
778 })
779 accelerator.end_training()
780 quit()
781
782
783if __name__ == "__main__":
784 main()
diff --git a/pipelines/stable_diffusion/no_check.py b/pipelines/stable_diffusion/no_check.py
new file mode 100644
index 0000000..06c2f72
--- /dev/null
+++ b/pipelines/stable_diffusion/no_check.py
@@ -0,0 +1,13 @@
1from diffusers import ModelMixin
2import torch
3
4
5class NoCheck(ModelMixin):
6 """Can be used in place of safety checker. Use responsibly and at your own risk."""
7
8 def __init__(self):
9 super().__init__()
10 self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))
11
12 def forward(self, images=None, **kwargs):
13 return images, [False]
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
new file mode 100644
index 0000000..ee7fc33
--- /dev/null
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -0,0 +1,690 @@
1# coding=utf-8
2# Copyright 2022 The HuggingFace Inc. team.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Conversion script for the LDM checkpoints. """
16
17import argparse
18import os
19
20import torch
21
22
23try:
24 from omegaconf import OmegaConf
25except ImportError:
26 raise ImportError(
27 "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
28 )
29
30from diffusers import (
31 AutoencoderKL,
32 DDIMScheduler,
33 LDMTextToImagePipeline,
34 LMSDiscreteScheduler,
35 PNDMScheduler,
36 StableDiffusionPipeline,
37 UNet2DConditionModel,
38)
39from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
40from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
41from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
42
43
44def shave_segments(path, n_shave_prefix_segments=1):
45 """
46 Removes segments. Positive values shave the first segments, negative shave the last segments.
47 """
48 if n_shave_prefix_segments >= 0:
49 return ".".join(path.split(".")[n_shave_prefix_segments:])
50 else:
51 return ".".join(path.split(".")[:n_shave_prefix_segments])
52
53
54def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
55 """
56 Updates paths inside resnets to the new naming scheme (local renaming)
57 """
58 mapping = []
59 for old_item in old_list:
60 new_item = old_item.replace("in_layers.0", "norm1")
61 new_item = new_item.replace("in_layers.2", "conv1")
62
63 new_item = new_item.replace("out_layers.0", "norm2")
64 new_item = new_item.replace("out_layers.3", "conv2")
65
66 new_item = new_item.replace("emb_layers.1", "time_emb_proj")
67 new_item = new_item.replace("skip_connection", "conv_shortcut")
68
69 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
70
71 mapping.append({"old": old_item, "new": new_item})
72
73 return mapping
74
75
76def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
77 """
78 Updates paths inside resnets to the new naming scheme (local renaming)
79 """
80 mapping = []
81 for old_item in old_list:
82 new_item = old_item
83
84 new_item = new_item.replace("nin_shortcut", "conv_shortcut")
85 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
86
87 mapping.append({"old": old_item, "new": new_item})
88
89 return mapping
90
91
92def renew_attention_paths(old_list, n_shave_prefix_segments=0):
93 """
94 Updates paths inside attentions to the new naming scheme (local renaming)
95 """
96 mapping = []
97 for old_item in old_list:
98 new_item = old_item
99
100 # new_item = new_item.replace('norm.weight', 'group_norm.weight')
101 # new_item = new_item.replace('norm.bias', 'group_norm.bias')
102
103 # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
104 # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
105
106 # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
107
108 mapping.append({"old": old_item, "new": new_item})
109
110 return mapping
111
112
113def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
114 """
115 Updates paths inside attentions to the new naming scheme (local renaming)
116 """
117 mapping = []
118 for old_item in old_list:
119 new_item = old_item
120
121 new_item = new_item.replace("norm.weight", "group_norm.weight")
122 new_item = new_item.replace("norm.bias", "group_norm.bias")
123
124 new_item = new_item.replace("q.weight", "query.weight")
125 new_item = new_item.replace("q.bias", "query.bias")
126
127 new_item = new_item.replace("k.weight", "key.weight")
128 new_item = new_item.replace("k.bias", "key.bias")
129
130 new_item = new_item.replace("v.weight", "value.weight")
131 new_item = new_item.replace("v.bias", "value.bias")
132
133 new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
134 new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
135
136 new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
137
138 mapping.append({"old": old_item, "new": new_item})
139
140 return mapping
141
142
143def assign_to_checkpoint(
144 paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
145):
146 """
147 This does the final conversion step: take locally converted weights and apply a global renaming
148 to them. It splits attention layers, and takes into account additional replacements
149 that may arise.
150
151 Assigns the weights to the new checkpoint.
152 """
153 assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
154
155 # Splits the attention layers into three variables.
156 if attention_paths_to_split is not None:
157 for path, path_map in attention_paths_to_split.items():
158 old_tensor = old_checkpoint[path]
159 channels = old_tensor.shape[0] // 3
160
161 target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
162
163 num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
164
165 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
166 query, key, value = old_tensor.split(channels // num_heads, dim=1)
167
168 checkpoint[path_map["query"]] = query.reshape(target_shape)
169 checkpoint[path_map["key"]] = key.reshape(target_shape)
170 checkpoint[path_map["value"]] = value.reshape(target_shape)
171
172 for path in paths:
173 new_path = path["new"]
174
175 # These have already been assigned
176 if attention_paths_to_split is not None and new_path in attention_paths_to_split:
177 continue
178
179 # Global renaming happens here
180 new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
181 new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
182 new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
183
184 if additional_replacements is not None:
185 for replacement in additional_replacements:
186 new_path = new_path.replace(replacement["old"], replacement["new"])
187
188 # proj_attn.weight has to be converted from conv 1D to linear
189 if "proj_attn.weight" in new_path:
190 checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
191 else:
192 checkpoint[new_path] = old_checkpoint[path["old"]]
193
194
195def conv_attn_to_linear(checkpoint):
196 keys = list(checkpoint.keys())
197 attn_keys = ["query.weight", "key.weight", "value.weight"]
198 for key in keys:
199 if ".".join(key.split(".")[-2:]) in attn_keys:
200 if checkpoint[key].ndim > 2:
201 checkpoint[key] = checkpoint[key][:, :, 0, 0]
202 elif "proj_attn.weight" in key:
203 if checkpoint[key].ndim > 2:
204 checkpoint[key] = checkpoint[key][:, :, 0]
205
206
207def create_unet_diffusers_config(original_config):
208 """
209 Creates a config for the diffusers based on the config of the LDM model.
210 """
211 unet_params = original_config.model.params.unet_config.params
212
213 block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
214
215 down_block_types = []
216 resolution = 1
217 for i in range(len(block_out_channels)):
218 block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
219 down_block_types.append(block_type)
220 if i != len(block_out_channels) - 1:
221 resolution *= 2
222
223 up_block_types = []
224 for i in range(len(block_out_channels)):
225 block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
226 up_block_types.append(block_type)
227 resolution //= 2
228
229 config = dict(
230 sample_size=unet_params.image_size,
231 in_channels=unet_params.in_channels,
232 out_channels=unet_params.out_channels,
233 down_block_types=tuple(down_block_types),
234 up_block_types=tuple(up_block_types),
235 block_out_channels=tuple(block_out_channels),
236 layers_per_block=unet_params.num_res_blocks,
237 cross_attention_dim=unet_params.context_dim,
238 attention_head_dim=unet_params.num_heads,
239 )
240
241 return config
242
243
244def create_vae_diffusers_config(original_config):
245 """
246 Creates a config for the diffusers based on the config of the LDM model.
247 """
248 vae_params = original_config.model.params.first_stage_config.params.ddconfig
249 _ = original_config.model.params.first_stage_config.params.embed_dim
250
251 block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
252 down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
253 up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
254
255 config = dict(
256 sample_size=vae_params.resolution,
257 in_channels=vae_params.in_channels,
258 out_channels=vae_params.out_ch,
259 down_block_types=tuple(down_block_types),
260 up_block_types=tuple(up_block_types),
261 block_out_channels=tuple(block_out_channels),
262 latent_channels=vae_params.z_channels,
263 layers_per_block=vae_params.num_res_blocks,
264 )
265 return config
266
267
268def create_diffusers_schedular(original_config):
269 schedular = DDIMScheduler(
270 num_train_timesteps=original_config.model.params.timesteps,
271 beta_start=original_config.model.params.linear_start,
272 beta_end=original_config.model.params.linear_end,
273 beta_schedule="scaled_linear",
274 )
275 return schedular
276
277
278def create_ldm_bert_config(original_config):
279 bert_params = original_config.model.params.cond_stage_config.params
280 config = LDMBertConfig(
281 d_model=bert_params.n_embed,
282 encoder_layers=bert_params.n_layer,
283 encoder_ffn_dim=bert_params.n_embed * 4,
284 )
285 return config
286
287
288def convert_ldm_unet_checkpoint(checkpoint, config):
289 """
290 Takes a state dict and a config, and returns a converted checkpoint.
291 """
292
293 # extract state_dict for UNet
294 unet_state_dict = {}
295 unet_key = "model.diffusion_model."
296 keys = list(checkpoint.keys())
297 for key in keys:
298 if key.startswith(unet_key):
299 unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
300
301 new_checkpoint = {}
302
303 new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
304 new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
305 new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
306 new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
307
308 new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
309 new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
310
311 new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
312 new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
313 new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
314 new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
315
316 # Retrieves the keys for the input blocks only
317 num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
318 input_blocks = {
319 layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
320 for layer_id in range(num_input_blocks)
321 }
322
323 # Retrieves the keys for the middle blocks only
324 num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
325 middle_blocks = {
326 layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
327 for layer_id in range(num_middle_blocks)
328 }
329
330 # Retrieves the keys for the output blocks only
331 num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
332 output_blocks = {
333 layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
334 for layer_id in range(num_output_blocks)
335 }
336
337 for i in range(1, num_input_blocks):
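        # Each diffusers down block holds `layers_per_block` resnets plus one downsampler,
        # hence the (layers_per_block + 1) stride when mapping LDM's flat input_blocks index.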
338 block_id = (i - 1) // (config["layers_per_block"] + 1)
339 layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
340
341 resnets = [
342 key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
343 ]
344 attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
345
346 if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
347 new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
348 f"input_blocks.{i}.0.op.weight"
349 )
350 new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
351 f"input_blocks.{i}.0.op.bias"
352 )
353
354 paths = renew_resnet_paths(resnets)
355 meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
356 assign_to_checkpoint(
357 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
358 )
359
360 if len(attentions):
361 paths = renew_attention_paths(attentions)
362 meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
363 assign_to_checkpoint(
364 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
365 )
366
367 resnet_0 = middle_blocks[0]
368 attentions = middle_blocks[1]
369 resnet_1 = middle_blocks[2]
370
371 resnet_0_paths = renew_resnet_paths(resnet_0)
372 assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
373
374 resnet_1_paths = renew_resnet_paths(resnet_1)
375 assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
376
377 attentions_paths = renew_attention_paths(attentions)
378 meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
379 assign_to_checkpoint(
380 attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
381 )
382
383 for i in range(num_output_blocks):
384 block_id = i // (config["layers_per_block"] + 1)
385 layer_in_block_id = i % (config["layers_per_block"] + 1)
386 output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
387 output_block_list = {}
388
389 for layer in output_block_layers:
390 layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
391 if layer_id in output_block_list:
392 output_block_list[layer_id].append(layer_name)
393 else:
394 output_block_list[layer_id] = [layer_name]
395
396 if len(output_block_list) > 1:
397 resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
398 attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
399
400 resnet_0_paths = renew_resnet_paths(resnets)
401 paths = renew_resnet_paths(resnets)
402
403 meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
404 assign_to_checkpoint(
405 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
406 )
407
408 if ["conv.weight", "conv.bias"] in output_block_list.values():
409 index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
410 new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
411 f"output_blocks.{i}.{index}.conv.weight"
412 ]
413 new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
414 f"output_blocks.{i}.{index}.conv.bias"
415 ]
416
417 # Clear attentions as they have been attributed above.
418 if len(attentions) == 2:
419 attentions = []
420
421 if len(attentions):
422 paths = renew_attention_paths(attentions)
423 meta_path = {
424 "old": f"output_blocks.{i}.1",
425 "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
426 }
427 assign_to_checkpoint(
428 paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
429 )
430 else:
431 resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
432 for path in resnet_0_paths:
433 old_path = ".".join(["output_blocks", str(i), path["old"]])
434 new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
435
436 new_checkpoint[new_path] = unet_state_dict[old_path]
437
438 return new_checkpoint
439
440
441def convert_ldm_vae_checkpoint(checkpoint, config):
442 # extract state dict for VAE
443 vae_state_dict = {}
444 vae_key = "first_stage_model."
445 keys = list(checkpoint.keys())
446 for key in keys:
447 if key.startswith(vae_key):
448 vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
449
450 new_checkpoint = {}
451
452 new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
453 new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
454 new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
455 new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
456 new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
457 new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
458
459 new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
460 new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
461 new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
462 new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
463 new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
464 new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
465
466 new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
467 new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
468 new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
469 new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
470
471 # Retrieves the keys for the encoder down blocks only
472 num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
473 down_blocks = {
474 layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
475 }
476
477 # Retrieves the keys for the decoder up blocks only
478 num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
479 up_blocks = {
480 layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
481 }
482
483 for i in range(num_down_blocks):
484 resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
485
486 if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
487 new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
488 f"encoder.down.{i}.downsample.conv.weight"
489 )
490 new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
491 f"encoder.down.{i}.downsample.conv.bias"
492 )
493
494 paths = renew_vae_resnet_paths(resnets)
495 meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
496 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
497
498 mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
499 num_mid_res_blocks = 2
500 for i in range(1, num_mid_res_blocks + 1):
501 resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
502
503 paths = renew_vae_resnet_paths(resnets)
504 meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
505 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
506
507 mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
508 paths = renew_vae_attention_paths(mid_attentions)
509 meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
510 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
511 conv_attn_to_linear(new_checkpoint)
512
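    # Note that the two layouts index the decoder's up blocks in opposite orders:
    # diffusers up_blocks.{i} is filled from LDM up.{num_up_blocks - 1 - i} below.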
513 for i in range(num_up_blocks):
514 block_id = num_up_blocks - 1 - i
515 resnets = [
516 key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
517 ]
518
519 if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
520 new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
521 f"decoder.up.{block_id}.upsample.conv.weight"
522 ]
523 new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
524 f"decoder.up.{block_id}.upsample.conv.bias"
525 ]
526
527 paths = renew_vae_resnet_paths(resnets)
528 meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
529 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
530
531 mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
532 num_mid_res_blocks = 2
533 for i in range(1, num_mid_res_blocks + 1):
534 resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
535
536 paths = renew_vae_resnet_paths(resnets)
537 meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
538 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
539
540 mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
541 paths = renew_vae_attention_paths(mid_attentions)
542 meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
543 assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
544 conv_attn_to_linear(new_checkpoint)
545 return new_checkpoint
546
547
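# Note: unlike the UNet/VAE converters above, this function reads the weights via
# attribute access (checkpoint.transformer...), so it appears to expect the
# instantiated LDM model rather than the flat state_dict loaded in __main__ below.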
548def convert_ldm_bert_checkpoint(checkpoint, config):
549 def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
550 hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
551 hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
552 hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
553
554 hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
555 hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
556
557 def _copy_linear(hf_linear, pt_linear):
558 hf_linear.weight = pt_linear.weight
559 hf_linear.bias = pt_linear.bias
560
561 def _copy_layer(hf_layer, pt_layer):
562 # copy layer norms
563 _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
564 _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
565
566 # copy attn
567 _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
568
569 # copy MLP
570 pt_mlp = pt_layer[1][1]
571 _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
572 _copy_linear(hf_layer.fc2, pt_mlp.net[2])
573
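    # Each LDMBert layer maps to two consecutive entries of the original transformer's
    # layer list: a (norm, attention) pair followed by a (norm, feed-forward) pair,
    # which is why the index is doubled (i += i) before slicing two entries at a time.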
574 def _copy_layers(hf_layers, pt_layers):
575 for i, hf_layer in enumerate(hf_layers):
576 if i != 0:
577 i += i
578 pt_layer = pt_layers[i : i + 2]
579 _copy_layer(hf_layer, pt_layer)
580
581 hf_model = LDMBertModel(config).eval()
582
583 # copy embeds
584 hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
585 hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
586
587 # copy layer norm
588 _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
589
590 # copy hidden layers
591 _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
592
593 _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
594
595 return hf_model
596
597
598if __name__ == "__main__":
599 parser = argparse.ArgumentParser()
600
601 parser.add_argument(
602 "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
603 )
604 # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
605 parser.add_argument(
606 "--original_config_file",
607 default=None,
608 type=str,
609 help="The YAML config file corresponding to the original architecture.",
610 )
611 parser.add_argument(
612 "--scheduler_type",
613 default="pndm",
614 type=str,
615 help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
616 )
617 parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
618
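    # Example invocation (file names are illustrative):
    #   python scripts/convert_original_stable_diffusion_to_diffusers.py \
    #       --checkpoint_path ./sd-v1-4.ckpt \
    #       --original_config_file ./v1-inference.yaml \
    #       --scheduler_type pndm \
    #       --dump_path ./stable-diffusion-v1-4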
619 args = parser.parse_args()
620
621 if args.original_config_file is None:
622 os.system(
623 "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
624 )
625 args.original_config_file = "./v1-inference.yaml"
626
627 original_config = OmegaConf.load(args.original_config_file)
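    # torch.load restores tensors to the device they were saved from by default;
    # on a CPU-only machine, torch.load(args.checkpoint_path, map_location="cpu") may be needed.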
628 checkpoint = torch.load(args.checkpoint_path)["state_dict"]
629
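    # The scheduler hyperparameters come from the DDPM section of the LDM config;
    # for PNDM, skip_prk_steps=True skips the Runge-Kutta warm-up steps, matching the
    # scheduler configuration shipped with the released Stable Diffusion pipelines.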
630 num_train_timesteps = original_config.model.params.timesteps
631 beta_start = original_config.model.params.linear_start
632 beta_end = original_config.model.params.linear_end
633 if args.scheduler_type == "pndm":
634 scheduler = PNDMScheduler(
635 beta_end=beta_end,
636 beta_schedule="scaled_linear",
637 beta_start=beta_start,
638 num_train_timesteps=num_train_timesteps,
639 skip_prk_steps=True,
640 )
641 elif args.scheduler_type == "lms":
642 scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
643 elif args.scheduler_type == "ddim":
644 scheduler = DDIMScheduler(
645 beta_start=beta_start,
646 beta_end=beta_end,
647 beta_schedule="scaled_linear",
648 clip_sample=False,
649 set_alpha_to_one=False,
650 )
651 else:
652 raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
653
654    # Convert the UNet2DConditionModel.
655 unet_config = create_unet_diffusers_config(original_config)
656 converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
657
658 unet = UNet2DConditionModel(**unet_config)
659 unet.load_state_dict(converted_unet_checkpoint)
660
661 # Convert the VAE model.
662 vae_config = create_vae_diffusers_config(original_config)
663 converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
664
665 vae = AutoencoderKL(**vae_config)
666 vae.load_state_dict(converted_vae_checkpoint)
667
668 # Convert the text model.
669 text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
670 if text_model_type == "FrozenCLIPEmbedder":
671 text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
672 tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
673 safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
674 feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
675 pipe = StableDiffusionPipeline(
676 vae=vae,
677 text_encoder=text_model,
678 tokenizer=tokenizer,
679 unet=unet,
680 scheduler=scheduler,
681 safety_checker=safety_checker,
682 feature_extractor=feature_extractor,
683 )
684 else:
685 text_config = create_ldm_bert_config(original_config)
686 text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
687 tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
688 pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
689
690 pipe.save_pretrained(args.dump_path)
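    # The dumped directory can be reloaded later with
    # StableDiffusionPipeline.from_pretrained(args.dump_path)
    # (or LDMTextToImagePipeline.from_pretrained for the LDM-BERT variant).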
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..75158a0
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,13 @@
1from setuptools import setup, find_packages
2
3setup(
4 name='textual-inversion-diff',
5 version='0.0.1',
6 description='',
7 packages=find_packages(),
8 install_requires=[
9 'torch',
10 'numpy',
11 'tqdm',
12 ],
13)
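# Minimal packaging metadata; the project can be installed for development with
# `pip install -e .`. Note that the conversion script above also imports diffusers,
# transformers and omegaconf, which are not listed in install_requires (presumably
# they are provided by environment.yaml).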