-rw-r--r--   .gitignore                                                  164
-rw-r--r--   .pep8                                                         2
-rw-r--r--   data.py                                                     145
-rw-r--r--   environment.yaml                                             36
-rw-r--r--   main.py                                                     784
-rw-r--r--   pipelines/stable_diffusion/no_check.py                       13
-rw-r--r--   scripts/convert_original_stable_diffusion_to_diffusers.py   690
-rw-r--r--   setup.py                                                     13
8 files changed, 1847 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..00e7681
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,164 @@
| 1 | # Byte-compiled / optimized / DLL files | ||
| 2 | __pycache__/ | ||
| 3 | *.py[cod] | ||
| 4 | *$py.class | ||
| 5 | |||
| 6 | # C extensions | ||
| 7 | *.so | ||
| 8 | |||
| 9 | # Distribution / packaging | ||
| 10 | .Python | ||
| 11 | build/ | ||
| 12 | develop-eggs/ | ||
| 13 | dist/ | ||
| 14 | downloads/ | ||
| 15 | eggs/ | ||
| 16 | .eggs/ | ||
| 17 | lib/ | ||
| 18 | lib64/ | ||
| 19 | parts/ | ||
| 20 | sdist/ | ||
| 21 | var/ | ||
| 22 | wheels/ | ||
| 23 | share/python-wheels/ | ||
| 24 | *.egg-info/ | ||
| 25 | .installed.cfg | ||
| 26 | *.egg | ||
| 27 | MANIFEST | ||
| 28 | |||
| 29 | # PyInstaller | ||
| 30 | # Usually these files are written by a python script from a template | ||
| 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
| 32 | *.manifest | ||
| 33 | *.spec | ||
| 34 | |||
| 35 | # Installer logs | ||
| 36 | pip-log.txt | ||
| 37 | pip-delete-this-directory.txt | ||
| 38 | |||
| 39 | # Unit test / coverage reports | ||
| 40 | htmlcov/ | ||
| 41 | .tox/ | ||
| 42 | .nox/ | ||
| 43 | .coverage | ||
| 44 | .coverage.* | ||
| 45 | .cache | ||
| 46 | nosetests.xml | ||
| 47 | coverage.xml | ||
| 48 | *.cover | ||
| 49 | *.py,cover | ||
| 50 | .hypothesis/ | ||
| 51 | .pytest_cache/ | ||
| 52 | cover/ | ||
| 53 | |||
| 54 | # Translations | ||
| 55 | *.mo | ||
| 56 | *.pot | ||
| 57 | |||
| 58 | # Django stuff: | ||
| 59 | *.log | ||
| 60 | local_settings.py | ||
| 61 | db.sqlite3 | ||
| 62 | db.sqlite3-journal | ||
| 63 | |||
| 64 | # Flask stuff: | ||
| 65 | instance/ | ||
| 66 | .webassets-cache | ||
| 67 | |||
| 68 | # Scrapy stuff: | ||
| 69 | .scrapy | ||
| 70 | |||
| 71 | # Sphinx documentation | ||
| 72 | docs/_build/ | ||
| 73 | |||
| 74 | # PyBuilder | ||
| 75 | .pybuilder/ | ||
| 76 | target/ | ||
| 77 | |||
| 78 | # Jupyter Notebook | ||
| 79 | .ipynb_checkpoints | ||
| 80 | |||
| 81 | # IPython | ||
| 82 | profile_default/ | ||
| 83 | ipython_config.py | ||
| 84 | |||
| 85 | # pyenv | ||
| 86 | # For a library or package, you might want to ignore these files since the code is | ||
| 87 | # intended to run in multiple environments; otherwise, check them in: | ||
| 88 | # .python-version | ||
| 89 | |||
| 90 | # pipenv | ||
| 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
| 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
| 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
| 94 | # install all needed dependencies. | ||
| 95 | #Pipfile.lock | ||
| 96 | |||
| 97 | # poetry | ||
| 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
| 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more | ||
| 100 | # commonly ignored for libraries. | ||
| 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
| 102 | #poetry.lock | ||
| 103 | |||
| 104 | # pdm | ||
| 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | ||
| 106 | #pdm.lock | ||
| 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | ||
| 108 | # in version control. | ||
| 109 | # https://pdm.fming.dev/#use-with-ide | ||
| 110 | .pdm.toml | ||
| 111 | |||
| 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||
| 113 | __pypackages__/ | ||
| 114 | |||
| 115 | # Celery stuff | ||
| 116 | celerybeat-schedule | ||
| 117 | celerybeat.pid | ||
| 118 | |||
| 119 | # SageMath parsed files | ||
| 120 | *.sage.py | ||
| 121 | |||
| 122 | # Environments | ||
| 123 | .env | ||
| 124 | .venv | ||
| 125 | env/ | ||
| 126 | venv/ | ||
| 127 | ENV/ | ||
| 128 | env.bak/ | ||
| 129 | venv.bak/ | ||
| 130 | |||
| 131 | # Spyder project settings | ||
| 132 | .spyderproject | ||
| 133 | .spyproject | ||
| 134 | |||
| 135 | # Rope project settings | ||
| 136 | .ropeproject | ||
| 137 | |||
| 138 | # mkdocs documentation | ||
| 139 | /site | ||
| 140 | |||
| 141 | # mypy | ||
| 142 | .mypy_cache/ | ||
| 143 | .dmypy.json | ||
| 144 | dmypy.json | ||
| 145 | |||
| 146 | # Pyre type checker | ||
| 147 | .pyre/ | ||
| 148 | |||
| 149 | # pytype static type analyzer | ||
| 150 | .pytype/ | ||
| 151 | |||
| 152 | # Cython debug symbols | ||
| 153 | cython_debug/ | ||
| 154 | |||
| 155 | # PyCharm | ||
| 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||
| 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||
| 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear | ||
| 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
| 160 | #.idea/ | ||
| 161 | |||
| 162 | text-inversion-model/ | ||
| 163 | conf.json | ||
| 164 | v1-inference.yaml | ||
| @@ -0,0 +1,2 @@ | |||
| 1 | [pycodestyle] | ||
| 2 | max_line_length = 120 | ||
| @@ -0,0 +1,145 @@ | |||
| 1 | import os | ||
| 2 | import numpy as np | ||
| 3 | import pandas as pd | ||
| 4 | import random | ||
| 5 | import PIL | ||
| 6 | import pytorch_lightning as pl | ||
| 7 | from PIL import Image | ||
| 8 | import torch | ||
| 9 | from torch.utils.data import Dataset, DataLoader, random_split | ||
| 10 | from torchvision import transforms | ||
| 11 | |||
| 12 | |||
| 13 | class CSVDataModule(pl.LightningDataModule): | ||
| 14 | def __init__(self, | ||
| 15 | batch_size, | ||
| 16 | data_root, | ||
| 17 | tokenizer, | ||
| 18 | size=512, | ||
| 19 | repeats=100, | ||
| 20 | interpolation="bicubic", | ||
| 21 | placeholder_token="*", | ||
| 22 | flip_p=0.5, | ||
| 23 | center_crop=False): | ||
| 24 | super().__init__() | ||
| 25 | |||
| 26 | self.data_root = data_root | ||
| 27 | self.tokenizer = tokenizer | ||
| 28 | self.size = size | ||
| 29 | self.repeats = repeats | ||
| 30 | self.placeholder_token = placeholder_token | ||
| 31 | self.center_crop = center_crop | ||
| 32 | self.flip_p = flip_p | ||
| 33 | self.interpolation = interpolation | ||
| 34 | |||
| 35 | self.batch_size = batch_size | ||
| 36 | |||
| 37 | def prepare_data(self): | ||
| 38 | metadata = pd.read_csv(f'{self.data_root}/list.csv') | ||
| 39 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | ||
| 40 | captions = [caption for caption in metadata['caption'].values] | ||
| 41 | skips = [skip for skip in metadata['skip'].values] | ||
| 42 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | ||
| 43 | |||
| 44 | def setup(self, stage=None): | ||
| 45 | train_set_size = int(len(self.data_full) * 0.8) | ||
| 46 | valid_set_size = len(self.data_full) - train_set_size | ||
| 47 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | ||
| 48 | |||
| 49 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, | ||
| 50 | flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) | ||
| 51 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, | ||
| 52 | flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) | ||
| 53 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) | ||
| 54 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) | ||
| 55 | |||
| 56 | def train_dataloader(self): | ||
| 57 | return self.train_dataloader_ | ||
| 58 | |||
| 59 | def val_dataloader(self): | ||
| 60 | return self.val_dataloader_ | ||
| 61 | |||
| 62 | |||
| 63 | class CSVDataset(Dataset): | ||
| 64 | def __init__(self, | ||
| 65 | data, | ||
| 66 | tokenizer, | ||
| 67 | size=512, | ||
| 68 | repeats=1, | ||
| 69 | interpolation="bicubic", | ||
| 70 | flip_p=0.5, | ||
| 71 | placeholder_token="*", | ||
| 72 | center_crop=False, | ||
| 73 | ): | ||
| 74 | |||
| 75 | self.data = data | ||
| 76 | self.tokenizer = tokenizer | ||
| 77 | |||
| 78 | self.num_images = len(self.data) | ||
| 79 | self._length = self.num_images * repeats | ||
| 80 | |||
| 81 | self.placeholder_token = placeholder_token | ||
| 82 | |||
| 83 | self.size = size | ||
| 84 | self.center_crop = center_crop | ||
| 85 | self.interpolation = {"linear": PIL.Image.LINEAR, | ||
| 86 | "bilinear": PIL.Image.BILINEAR, | ||
| 87 | "bicubic": PIL.Image.BICUBIC, | ||
| 88 | "lanczos": PIL.Image.LANCZOS, | ||
| 89 | }[interpolation] | ||
| 90 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) | ||
| 91 | |||
| 92 | self.cache = {} | ||
| 93 | |||
| 94 | def __len__(self): | ||
| 95 | return self._length | ||
| 96 | |||
| 97 | def get_example(self, i, flipped): | ||
| 98 | image_path, text = self.data[i % self.num_images] | ||
| 99 | |||
| 100 | if image_path in self.cache: | ||
| 101 | return self.cache[image_path] | ||
| 102 | |||
| 103 | example = {} | ||
| 104 | image = Image.open(image_path) | ||
| 105 | |||
| 106 | if not image.mode == "RGB": | ||
| 107 | image = image.convert("RGB") | ||
| 108 | |||
| 109 | text = text.format(self.placeholder_token) | ||
| 110 | |||
| 111 | example["prompt"] = text | ||
| 112 | example["input_ids"] = self.tokenizer( | ||
| 113 | text, | ||
| 114 | padding="max_length", | ||
| 115 | truncation=True, | ||
| 116 | max_length=self.tokenizer.model_max_length, | ||
| 117 | return_tensors="pt", | ||
| 118 | ).input_ids[0] | ||
| 119 | |||
| 120 | # default to score-sde preprocessing | ||
| 121 | img = np.array(image).astype(np.uint8) | ||
| 122 | |||
| 123 | if self.center_crop: | ||
| 124 | crop = min(img.shape[0], img.shape[1]) | ||
| 125 | h, w = img.shape[0], img.shape[1] | ||
| 126 | img = img[(h - crop) // 2:(h + crop) // 2, | ||
| 127 | (w - crop) // 2:(w + crop) // 2] | ||
| 128 | |||
| 129 | image = Image.fromarray(img) | ||
| 130 | image = image.resize((self.size, self.size), | ||
| 131 | resample=self.interpolation) | ||
| 132 | image = self.flip(image) | ||
| 133 | image = np.array(image).astype(np.uint8) | ||
| 134 | image = (image / 127.5 - 1.0).astype(np.float32) | ||
| 135 | |||
| 136 | example["key"] = "-".join([image_path, "-", str(flipped)]) | ||
| 137 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) | ||
| 138 | |||
| 139 | self.cache[image_path] = example | ||
| 140 | return example | ||
| 141 | |||
| 142 | def __getitem__(self, i): | ||
| 143 | flipped = random.choice([False, True]) | ||
| 144 | example = self.get_example(i, flipped) | ||
| 145 | return example | ||
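The data module above expects data_root to contain a list.csv with image, caption and skip columns; rows whose skip column is "x" are dropped, and each caption is passed through str.format() so a "{}" in the caption is replaced by the placeholder token. A minimal standalone sketch, assuming a hypothetical data/my-concept folder laid out that way and a diffusers-format checkpoint providing the CLIP tokenizer:

from transformers import CLIPTokenizer
from data import CSVDataModule

# Hypothetical paths; any CLIP tokenizer in diffusers layout works here.
tokenizer = CLIPTokenizer.from_pretrained("pretrained/stable-diffusion/tokenizer")

datamodule = CSVDataModule(
    batch_size=1,
    data_root="data/my-concept",   # must contain list.csv and the images it references
    tokenizer=tokenizer,
    size=512,
    repeats=10,
    placeholder_token="<my-token>",
)
datamodule.prepare_data()   # reads list.csv and filters rows marked with skip == "x"
datamodule.setup()          # 80/20 train/validation split, builds both dataloaders

batch = next(iter(datamodule.train_dataloader()))
print(batch["prompt"], batch["pixel_values"].shape, batch["input_ids"].shape)

main.py builds the same datamodule internally from --train_data_dir, --train_batch_size, --resolution, --repeats, --placeholder_token and --center_crop.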
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000..a460158
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,36 @@
| 1 | name: ldd | ||
| 2 | channels: | ||
| 3 | - pytorch | ||
| 4 | - defaults | ||
| 5 | dependencies: | ||
| 6 | - cudatoolkit=11.3 | ||
| 7 | - numpy=1.22.3 | ||
| 8 | - pip=20.3 | ||
| 9 | - python=3.8.10 | ||
| 10 | - pytorch=1.12.1 | ||
| 11 | - torchvision=0.13.1 | ||
| 12 | - pandas=1.4.3 | ||
| 13 | - pip: | ||
| 14 | - -e . | ||
| 15 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers | ||
| 16 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip | ||
| 17 | - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion | ||
| 18 | - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion | ||
| 19 | - accelerate==0.12.0 | ||
| 20 | - albumentations==1.1.0 | ||
| 21 | - diffusers==0.3.0 | ||
| 22 | - einops==0.4.1 | ||
| 23 | - imageio-ffmpeg==0.4.7 | ||
| 24 | - imageio==2.14.1 | ||
| 25 | - kornia==0.6 | ||
| 26 | - pudb==2019.2 | ||
| 27 | - omegaconf==2.1.1 | ||
| 28 | - opencv-python-headless==4.6.0.66 | ||
| 29 | - python-slugify>=6.1.2 | ||
| 30 | - pytorch-lightning==1.7.7 | ||
| 31 | - setuptools==59.5.0 | ||
| 32 | - streamlit>=0.73.1 | ||
| 33 | - test-tube>=0.7.5 | ||
| 34 | - torch-fidelity==0.3.0 | ||
| 35 | - torchmetrics==0.9.3 | ||
| 36 | - transformers==4.19.2 | ||
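Assuming conda (or mamba) is installed, running conda env create -f environment.yaml from the repository root and then conda activate ldd reproduces this environment; the pip section also installs the repository itself in editable mode (-e .) alongside the editable git checkouts listed above.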
| @@ -0,0 +1,784 @@ | |||
| 1 | import argparse | ||
| 2 | import itertools | ||
| 3 | import math | ||
| 4 | import os | ||
| 5 | import random | ||
| 6 | import datetime | ||
| 7 | from pathlib import Path | ||
| 8 | from typing import Optional | ||
| 9 | |||
| 10 | import numpy as np | ||
| 11 | import torch | ||
| 12 | import torch.nn as nn | ||
| 13 | import torch.nn.functional as F | ||
| 14 | import torch.utils.checkpoint | ||
| 15 | from torch.utils.data import Dataset | ||
| 16 | |||
| 17 | import PIL | ||
| 18 | from accelerate import Accelerator | ||
| 19 | from accelerate.logging import get_logger | ||
| 20 | from accelerate.utils import LoggerType, set_seed | ||
| 21 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel | ||
| 22 | from diffusers.optimization import get_scheduler | ||
| 23 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
| 24 | from einops import rearrange | ||
| 25 | from pipelines.stable_diffusion.no_check import NoCheck | ||
| 26 | from huggingface_hub import HfFolder, Repository, whoami | ||
| 27 | from PIL import Image | ||
| 28 | from tqdm.auto import tqdm | ||
| 29 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | ||
| 30 | from slugify import slugify | ||
| 31 | import json | ||
| 32 | import os | ||
| 33 | import sys | ||
| 34 | |||
| 35 | from data import CSVDataModule | ||
| 36 | |||
| 37 | logger = get_logger(__name__) | ||
| 38 | |||
| 39 | |||
| 40 | def parse_args(): | ||
| 41 | parser = argparse.ArgumentParser( | ||
| 42 | description="Simple example of a training script.") | ||
| 43 | parser.add_argument( | ||
| 44 | "--pretrained_model_name_or_path", | ||
| 45 | type=str, | ||
| 46 | default=None, | ||
| 47 | help="Path to pretrained model or model identifier from huggingface.co/models.", | ||
| 48 | ) | ||
| 49 | parser.add_argument( | ||
| 50 | "--tokenizer_name", | ||
| 51 | type=str, | ||
| 52 | default=None, | ||
| 53 | help="Pretrained tokenizer name or path if not the same as model_name", | ||
| 54 | ) | ||
| 55 | parser.add_argument( | ||
| 56 | "--train_data_dir", type=str, default=None, help="A folder containing the training data." | ||
| 57 | ) | ||
| 58 | parser.add_argument( | ||
| 59 | "--placeholder_token", | ||
| 60 | type=str, | ||
| 61 | default=None, | ||
| 62 | help="A token to use as a placeholder for the concept.", | ||
| 63 | ) | ||
| 64 | parser.add_argument( | ||
| 65 | "--initializer_token", type=str, default=None, help="A token to use as initializer word." | ||
| 66 | ) | ||
| 67 | parser.add_argument( | ||
| 68 | "--vectors_per_token", type=int, default=1, help="Vectors per token." | ||
| 69 | ) | ||
| 70 | parser.add_argument("--repeats", type=int, default=100, | ||
| 71 | help="How many times to repeat the training data.") | ||
| 72 | parser.add_argument( | ||
| 73 | "--output_dir", | ||
| 74 | type=str, | ||
| 75 | default="text-inversion-model", | ||
| 76 | help="The output directory where the model predictions and checkpoints will be written.", | ||
| 77 | ) | ||
| 78 | parser.add_argument("--seed", type=int, default=None, | ||
| 79 | help="A seed for reproducible training.") | ||
| 80 | parser.add_argument( | ||
| 81 | "--resolution", | ||
| 82 | type=int, | ||
| 83 | default=512, | ||
| 84 | help=( | ||
| 85 | "The resolution for input images; all the images in the train/validation dataset will be resized to this" | ||
| 86 | " resolution" | ||
| 87 | ), | ||
| 88 | ) | ||
| 89 | parser.add_argument( | ||
| 90 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" | ||
| 91 | ) | ||
| 92 | parser.add_argument( | ||
| 93 | "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." | ||
| 94 | ) | ||
| 95 | parser.add_argument("--num_train_epochs", type=int, default=100) | ||
| 96 | parser.add_argument( | ||
| 97 | "--max_train_steps", | ||
| 98 | type=int, | ||
| 99 | default=5000, | ||
| 100 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | ||
| 101 | ) | ||
| 102 | parser.add_argument( | ||
| 103 | "--gradient_accumulation_steps", | ||
| 104 | type=int, | ||
| 105 | default=1, | ||
| 106 | help="Number of update steps to accumulate before performing a backward/update pass.", | ||
| 107 | ) | ||
| 108 | parser.add_argument( | ||
| 109 | "--learning_rate", | ||
| 110 | type=float, | ||
| 111 | default=1e-4, | ||
| 112 | help="Initial learning rate (after the potential warmup period) to use.", | ||
| 113 | ) | ||
| 114 | parser.add_argument( | ||
| 115 | "--scale_lr", | ||
| 116 | action="store_true", | ||
| 117 | default=True, | ||
| 118 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | ||
| 119 | ) | ||
| 120 | parser.add_argument( | ||
| 121 | "--lr_scheduler", | ||
| 122 | type=str, | ||
| 123 | default="constant", | ||
| 124 | help=( | ||
| 125 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | ||
| 126 | ' "constant", "constant_with_warmup"]' | ||
| 127 | ), | ||
| 128 | ) | ||
| 129 | parser.add_argument( | ||
| 130 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." | ||
| 131 | ) | ||
| 132 | parser.add_argument("--adam_beta1", type=float, default=0.9, | ||
| 133 | help="The beta1 parameter for the Adam optimizer.") | ||
| 134 | parser.add_argument("--adam_beta2", type=float, default=0.999, | ||
| 135 | help="The beta2 parameter for the Adam optimizer.") | ||
| 136 | parser.add_argument("--adam_weight_decay", type=float, | ||
| 137 | default=1e-2, help="Weight decay to use.") | ||
| 138 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, | ||
| 139 | help="Epsilon value for the Adam optimizer") | ||
| 140 | parser.add_argument( | ||
| 141 | "--mixed_precision", | ||
| 142 | type=str, | ||
| 143 | default="no", | ||
| 144 | choices=["no", "fp16", "bf16"], | ||
| 145 | help=( | ||
| 146 | "Whether to use mixed precision. Choose " | ||
| 147 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 " | ||
| 148 | "and an Nvidia Ampere GPU." | ||
| 149 | ), | ||
| 150 | ) | ||
| 151 | parser.add_argument("--local_rank", type=int, default=-1, | ||
| 152 | help="For distributed training: local_rank") | ||
| 153 | parser.add_argument( | ||
| 154 | "--checkpoint_frequency", | ||
| 155 | type=int, | ||
| 156 | default=500, | ||
| 157 | help="How often to save a checkpoint and sample image", | ||
| 158 | ) | ||
| 159 | parser.add_argument( | ||
| 160 | "--sample_image_size", | ||
| 161 | type=int, | ||
| 162 | default=512, | ||
| 163 | help="Size of sample images", | ||
| 164 | ) | ||
| 165 | parser.add_argument( | ||
| 166 | "--stable_sample_batches", | ||
| 167 | type=int, | ||
| 168 | default=1, | ||
| 169 | help="Number of fixed seed sample batches to generate per checkpoint", | ||
| 170 | ) | ||
| 171 | parser.add_argument( | ||
| 172 | "--random_sample_batches", | ||
| 173 | type=int, | ||
| 174 | default=1, | ||
| 175 | help="Number of random seed sample batches to generate per checkpoint", | ||
| 176 | ) | ||
| 177 | parser.add_argument( | ||
| 178 | "--sample_batch_size", | ||
| 179 | type=int, | ||
| 180 | default=1, | ||
| 181 | help="Number of samples to generate per batch", | ||
| 182 | ) | ||
| 183 | parser.add_argument( | ||
| 184 | "--sample_steps", | ||
| 185 | type=int, | ||
| 186 | default=50, | ||
| 187 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | ||
| 188 | ) | ||
| 189 | parser.add_argument( | ||
| 190 | "--resume_from", | ||
| 191 | type=str, | ||
| 192 | default=None, | ||
| 193 | help="Path to a directory to resume training from (e.g., logs/token_name/2022-09-22T23-36-27)" | ||
| 194 | ) | ||
| 195 | parser.add_argument( | ||
| 196 | "--resume_checkpoint", | ||
| 197 | type=str, | ||
| 198 | default=None, | ||
| 199 | help="Path to a specific checkpoint to resume training from (e.g., logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)." | ||
| 200 | ) | ||
| 201 | parser.add_argument( | ||
| 202 | "--config", | ||
| 203 | type=str, | ||
| 204 | default=None, | ||
| 205 | help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this." | ||
| 206 | ) | ||
| 207 | |||
| 208 | args = parser.parse_args() | ||
| 209 | if args.resume_from is not None: | ||
| 210 | with open(f"{args.resume_from}/resume.json", 'rt') as f: | ||
| 211 | args = parser.parse_args( | ||
| 212 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
| 213 | elif args.config is not None: | ||
| 214 | with open(args.config, 'rt') as f: | ||
| 215 | args = parser.parse_args( | ||
| 216 | namespace=argparse.Namespace(**json.load(f))) | ||
| 217 | |||
| 218 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) | ||
| 219 | if env_local_rank != -1 and env_local_rank != args.local_rank: | ||
| 220 | args.local_rank = env_local_rank | ||
| 221 | |||
| 222 | if args.train_data_dir is None: | ||
| 223 | raise ValueError("You must specify --train_data_dir") | ||
| 224 | |||
| 225 | if args.pretrained_model_name_or_path is None: | ||
| 226 | raise ValueError("You must specify --pretrained_model_name_or_path") | ||
| 227 | |||
| 228 | if args.placeholder_token is None: | ||
| 229 | raise ValueError("You must specify --placeholder_token") | ||
| 230 | |||
| 231 | if args.initializer_token is None: | ||
| 232 | raise ValueError("You must specify --initializer_token") | ||
| 233 | |||
| 234 | if args.output_dir is None: | ||
| 235 | raise ValueError("You must specify --output_dir") | ||
| 236 | |||
| 237 | return args | ||
| 238 | |||
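# Illustration (hypothetical values): the resume/config handling above round-trips
# arguments through JSON. save_resume_file() below writes {"args": vars(args)} plus a few
# extra keys to <output_dir>/<token>/<timestamp>/resume.json, and parse_args() feeds that
# dict back in as the argparse namespace, so a resume.json looks roughly like:
#
#   {"args": {"placeholder_token": "<my-token>", "initializer_token": "toy",
#             "train_data_dir": "data/my-concept", "global_step": 500,
#             "resume_checkpoint": "text-inversion-model/my-token/<timestamp>/checkpoints/last.bin"}}
#
# A file passed via --config uses the same flag names but without the "args" wrapper.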
| 239 | |||
| 240 | def freeze_params(params): | ||
| 241 | for param in params: | ||
| 242 | param.requires_grad = False | ||
| 243 | |||
| 244 | |||
| 245 | def save_resume_file(basepath, args, extra={}): | ||
| 246 | info = {"args": vars(args)} | ||
| 247 | info["args"].update(extra) | ||
| 248 | with open(f"{basepath}/resume.json", "w") as f: | ||
| 249 | json.dump(info, f, indent=4) | ||
| 250 | |||
| 251 | |||
| 252 | def make_grid(images, rows, cols): | ||
| 253 | w, h = images[0].size | ||
| 254 | grid = Image.new('RGB', size=(cols*w, rows*h)) | ||
| 255 | for i, image in enumerate(images): | ||
| 256 | grid.paste(image, box=(i % cols*w, i//cols*h)) | ||
| 257 | return grid | ||
| 258 | |||
| 259 | |||
| 260 | class Checkpointer: | ||
| 261 | def __init__( | ||
| 262 | self, | ||
| 263 | datamodule, | ||
| 264 | accelerator, | ||
| 265 | vae, | ||
| 266 | unet, | ||
| 267 | tokenizer, | ||
| 268 | placeholder_token, | ||
| 269 | placeholder_token_id, | ||
| 270 | output_dir, | ||
| 271 | sample_image_size, | ||
| 272 | random_sample_batches, | ||
| 273 | sample_batch_size, | ||
| 274 | stable_sample_batches, | ||
| 275 | seed | ||
| 276 | ): | ||
| 277 | self.datamodule = datamodule | ||
| 278 | self.accelerator = accelerator | ||
| 279 | self.vae = vae | ||
| 280 | self.unet = unet | ||
| 281 | self.tokenizer = tokenizer | ||
| 282 | self.placeholder_token = placeholder_token | ||
| 283 | self.placeholder_token_id = placeholder_token_id | ||
| 284 | self.output_dir = output_dir | ||
| 285 | self.sample_image_size = sample_image_size | ||
| 286 | self.seed = seed | ||
| 287 | self.random_sample_batches = random_sample_batches | ||
| 288 | self.sample_batch_size = sample_batch_size | ||
| 289 | self.stable_sample_batches = stable_sample_batches | ||
| 290 | |||
| 291 | @torch.no_grad() | ||
| 292 | def checkpoint(self, step, text_encoder, save_samples=True, path=None): | ||
| 293 | print("Saving checkpoint for step %d..." % step) | ||
| 294 | with self.accelerator.autocast(): | ||
| 295 | if path is None: | ||
| 296 | checkpoints_path = f"{self.output_dir}/checkpoints" | ||
| 297 | os.makedirs(checkpoints_path, exist_ok=True) | ||
| 298 | |||
| 299 | unwrapped = self.accelerator.unwrap_model(text_encoder) | ||
| 300 | |||
| 301 | # Save a checkpoint | ||
| 302 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | ||
| 303 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | ||
| 304 | |||
| 305 | filename = "%s_%d.bin" % (slugify(self.placeholder_token), step) | ||
| 306 | if path is not None: | ||
| 307 | torch.save(learned_embeds_dict, path) | ||
| 308 | else: | ||
| 309 | torch.save(learned_embeds_dict, | ||
| 310 | f"{checkpoints_path}/{filename}") | ||
| 311 | torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") | ||
| 312 | del unwrapped | ||
| 313 | del learned_embeds | ||
| 314 | |||
| 315 | @torch.no_grad() | ||
| 316 | def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): | ||
| 317 | samples_path = f"{self.output_dir}/samples/{mode}" | ||
| 318 | os.makedirs(samples_path, exist_ok=True) | ||
| 319 | checker = NoCheck() | ||
| 320 | |||
| 321 | unwrapped = self.accelerator.unwrap_model(text_encoder) | ||
| 322 | # Save a sample image | ||
| 323 | pipeline = StableDiffusionPipeline( | ||
| 324 | text_encoder=unwrapped, | ||
| 325 | vae=self.vae, | ||
| 326 | unet=self.unet, | ||
| 327 | tokenizer=self.tokenizer, | ||
| 328 | scheduler=LMSDiscreteScheduler( | ||
| 329 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 330 | ), | ||
| 331 | safety_checker=NoCheck(), | ||
| 332 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), | ||
| 333 | ).to(self.accelerator.device) | ||
| 334 | pipeline.enable_attention_slicing() | ||
| 335 | |||
| 336 | data = { | ||
| 337 | "training": self.datamodule.train_dataloader(), | ||
| 338 | "validation": self.datamodule.val_dataloader(), | ||
| 339 | }[mode] | ||
| 340 | |||
| 341 | if mode == "validation" and self.stable_sample_batches > 0: | ||
| 342 | stable_latents = torch.randn( | ||
| 343 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | ||
| 344 | device=pipeline.device, | ||
| 345 | generator=torch.Generator(device=pipeline.device).manual_seed(self.seed), | ||
| 346 | ) | ||
| 347 | |||
| 348 | all_samples = [] | ||
| 349 | filename = "stable_step_%d.png" % step | ||
| 350 | |||
| 351 | # Generate and save stable samples | ||
| 352 | for i in range(0, self.stable_sample_batches): | ||
| 353 | prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size] | ||
| 354 | samples = pipeline( | ||
| 355 | prompt=prompt, | ||
| 356 | height=self.sample_image_size, | ||
| 357 | latents=stable_latents, | ||
| 358 | width=self.sample_image_size, | ||
| 359 | guidance_scale=guidance_scale, | ||
| 360 | eta=eta, | ||
| 361 | num_inference_steps=num_inference_steps, | ||
| 362 | output_type='pil' | ||
| 363 | )["sample"] | ||
| 364 | |||
| 365 | all_samples += samples | ||
| 366 | del samples | ||
| 367 | |||
| 368 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | ||
| 369 | image_grid.save(f"{samples_path}/{filename}") | ||
| 370 | |||
| 371 | del all_samples | ||
| 372 | del image_grid | ||
| 373 | del stable_latents | ||
| 374 | |||
| 375 | all_samples = [] | ||
| 376 | filename = "step_%d.png" % step | ||
| 377 | |||
| 378 | # Generate and save random samples | ||
| 379 | for i in range(0, self.random_sample_batches): | ||
| 380 | prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size] | ||
| 381 | samples = pipeline( | ||
| 382 | prompt=prompt, | ||
| 383 | height=self.sample_image_size, | ||
| 384 | width=self.sample_image_size, | ||
| 385 | guidance_scale=guidance_scale, | ||
| 386 | eta=eta, | ||
| 387 | num_inference_steps=num_inference_steps, | ||
| 388 | output_type='pil' | ||
| 389 | )["sample"] | ||
| 390 | |||
| 391 | all_samples += samples | ||
| 392 | del samples | ||
| 393 | |||
| 394 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | ||
| 395 | image_grid.save(f"{samples_path}/{filename}") | ||
| 396 | |||
| 397 | del all_samples | ||
| 398 | del image_grid | ||
| 399 | |||
| 400 | del checker | ||
| 401 | del unwrapped | ||
| 402 | del pipeline | ||
| 403 | torch.cuda.empty_cache() | ||
| 404 | |||
| 405 | |||
| 406 | class ImageToLatents(): | ||
| 407 | def __init__(self, vae): | ||
| 408 | self.vae = vae | ||
| 409 | self.encoded_pixel_values_cache = {} | ||
| 410 | |||
| 411 | @torch.no_grad() | ||
| 412 | def __call__(self, batch): | ||
| 413 | key = "|".join(batch["key"]) | ||
| 414 | if self.encoded_pixel_values_cache.get(key, None) is None: | ||
| 415 | self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist | ||
| 416 | latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215 | ||
| 417 | return latents | ||
| 418 | |||
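# Note on ImageToLatents above: 0.18215 is the standard Stable Diffusion scaling factor
# applied when feeding VAE latents to the diffusion model. Encodings are cached per
# batch["key"] (image path plus flip flag), so each distinct batch goes through the VAE
# encoder only once per run, and the unconditional .half() assumes the VAE is being run
# under fp16/mixed precision.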
| 419 | |||
| 420 | def main(): | ||
| 421 | args = parse_args() | ||
| 422 | |||
| 423 | global_step_offset = 0 | ||
| 424 | if args.resume_from is not None: | ||
| 425 | basepath = f"{args.resume_from}" | ||
| 426 | print("Resuming state from %s" % args.resume_from) | ||
| 427 | with open(f"{basepath}/resume.json", 'r') as f: | ||
| 428 | state = json.load(f) | ||
| 429 | global_step_offset = state["args"].get("global_step", 0) | ||
| 430 | |||
| 431 | print("We've trained %d steps so far" % global_step_offset) | ||
| 432 | else: | ||
| 433 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
| 434 | basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}" | ||
| 435 | os.makedirs(basepath, exist_ok=True) | ||
| 436 | |||
| 437 | accelerator = Accelerator( | ||
| 438 | log_with=LoggerType.TENSORBOARD, | ||
| 439 | logging_dir=f"{basepath}", | ||
| 440 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 441 | mixed_precision=args.mixed_precision | ||
| 442 | ) | ||
| 443 | |||
| 444 | # If passed along, set the training seed now. | ||
| 445 | if args.seed is not None: | ||
| 446 | set_seed(args.seed) | ||
| 447 | |||
| 448 | # Load the tokenizer and add the placeholder token as an additional special token | ||
| 449 | if args.tokenizer_name: | ||
| 450 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) | ||
| 451 | elif args.pretrained_model_name_or_path: | ||
| 452 | tokenizer = CLIPTokenizer.from_pretrained( | ||
| 453 | args.pretrained_model_name_or_path + '/tokenizer' | ||
| 454 | ) | ||
| 455 | |||
| 456 | # Add the placeholder token in tokenizer | ||
| 457 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) | ||
| 458 | if num_added_tokens == 0: | ||
| 459 | raise ValueError( | ||
| 460 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" | ||
| 461 | " `placeholder_token` that is not already in the tokenizer." | ||
| 462 | ) | ||
| 463 | |||
| 464 | # Convert the initializer_token, placeholder_token to ids | ||
| 465 | initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) | ||
| 466 | # Check if initializer_token is a single token or a sequence of tokens | ||
| 467 | if args.vectors_per_token % len(initializer_token_ids) != 0: | ||
| 468 | raise ValueError( | ||
| 469 | f"vectors_per_token ({args.vectors_per_token}) must be divisible by the number of initializer tokens ({len(initializer_token_ids)}).") | ||
| 470 | |||
| 471 | initializer_token_ids = torch.tensor(initializer_token_ids) | ||
| 472 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | ||
| 473 | |||
| 474 | # Load models and create wrapper for stable diffusion | ||
| 475 | text_encoder = CLIPTextModel.from_pretrained( | ||
| 476 | args.pretrained_model_name_or_path + '/text_encoder', | ||
| 477 | ) | ||
| 478 | vae = AutoencoderKL.from_pretrained( | ||
| 479 | args.pretrained_model_name_or_path + '/vae', | ||
| 480 | ) | ||
| 481 | unet = UNet2DConditionModel.from_pretrained( | ||
| 482 | args.pretrained_model_name_or_path + '/unet', | ||
| 483 | ) | ||
| 484 | |||
| 485 | slice_size = unet.config.attention_head_dim // 2 | ||
| 486 | unet.set_attention_slice(slice_size) | ||
| 487 | |||
| 488 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | ||
| 489 | text_encoder.resize_token_embeddings(len(tokenizer)) | ||
| 490 | |||
| 491 | # Initialise the newly added placeholder token with the embeddings of the initializer token | ||
| 492 | token_embeds = text_encoder.get_input_embeddings().weight.data | ||
| 493 | |||
| 494 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
| 495 | |||
| 496 | if args.resume_checkpoint is not None: | ||
| 497 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[ | ||
| 498 | args.placeholder_token] | ||
| 499 | else: | ||
| 500 | token_embeds[placeholder_token_id] = initializer_token_embeddings | ||
| 501 | |||
| 502 | # Freeze vae and unet | ||
| 503 | freeze_params(vae.parameters()) | ||
| 504 | freeze_params(unet.parameters()) | ||
| 505 | # Freeze all parameters except for the token embeddings in text encoder | ||
| 506 | params_to_freeze = itertools.chain( | ||
| 507 | text_encoder.text_model.encoder.parameters(), | ||
| 508 | text_encoder.text_model.final_layer_norm.parameters(), | ||
| 509 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
| 510 | ) | ||
| 511 | freeze_params(params_to_freeze) | ||
| 512 | |||
| 513 | if args.scale_lr: | ||
| 514 | args.learning_rate = ( | ||
| 515 | args.learning_rate * args.gradient_accumulation_steps * | ||
| 516 | args.train_batch_size * accelerator.num_processes | ||
| 517 | ) | ||
| 518 | |||
| 519 | # Initialize the optimizer | ||
| 520 | optimizer = torch.optim.AdamW( | ||
| 521 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings | ||
| 522 | lr=args.learning_rate, | ||
| 523 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 524 | weight_decay=args.adam_weight_decay, | ||
| 525 | eps=args.adam_epsilon, | ||
| 526 | ) | ||
| 527 | |||
| 528 | # TODO (patil-suraj): load scheduler using args | ||
| 529 | noise_scheduler = DDPMScheduler( | ||
| 530 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" | ||
| 531 | ) | ||
| 532 | |||
| 533 | datamodule = CSVDataModule( | ||
| 534 | data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer, | ||
| 535 | size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats, | ||
| 536 | center_crop=args.center_crop) | ||
| 537 | |||
| 538 | datamodule.prepare_data() | ||
| 539 | datamodule.setup() | ||
| 540 | |||
| 541 | train_dataloader = datamodule.train_dataloader() | ||
| 542 | val_dataloader = datamodule.val_dataloader() | ||
| 543 | |||
| 544 | checkpointer = Checkpointer( | ||
| 545 | datamodule=datamodule, | ||
| 546 | accelerator=accelerator, | ||
| 547 | vae=vae, | ||
| 548 | unet=unet, | ||
| 549 | tokenizer=tokenizer, | ||
| 550 | placeholder_token=args.placeholder_token, | ||
| 551 | placeholder_token_id=placeholder_token_id, | ||
| 552 | output_dir=basepath, | ||
| 553 | sample_image_size=args.sample_image_size, | ||
| 554 | sample_batch_size=args.sample_batch_size, | ||
| 555 | random_sample_batches=args.random_sample_batches, | ||
| 556 | stable_sample_batches=args.stable_sample_batches, | ||
| 557 | seed=args.seed | ||
| 558 | ) | ||
| 559 | |||
| 560 | # Scheduler and math around the number of training steps. | ||
| 561 | overrode_max_train_steps = False | ||
| 562 | num_update_steps_per_epoch = math.ceil( | ||
| 563 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | ||
| 564 | if args.max_train_steps is None: | ||
| 565 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
| 566 | overrode_max_train_steps = True | ||
| 567 | |||
| 568 | lr_scheduler = get_scheduler( | ||
| 569 | args.lr_scheduler, | ||
| 570 | optimizer=optimizer, | ||
| 571 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, | ||
| 572 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
| 573 | ) | ||
| 574 | |||
| 575 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | ||
| 576 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 577 | ) | ||
| 578 | |||
| 579 | # Move vae and unet to device | ||
| 580 | vae.to(accelerator.device) | ||
| 581 | unet.to(accelerator.device) | ||
| 582 | |||
| 583 | # Keep vae and unet in eval mode as we don't train these | ||
| 584 | vae.eval() | ||
| 585 | unet.eval() | ||
| 586 | |||
| 587 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | ||
| 588 | num_update_steps_per_epoch = math.ceil( | ||
| 589 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | ||
| 590 | if overrode_max_train_steps: | ||
| 591 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | ||
| 592 | # Afterwards we recalculate our number of training epochs | ||
| 593 | args.num_train_epochs = math.ceil( | ||
| 594 | args.max_train_steps / num_update_steps_per_epoch) | ||
| 595 | |||
| 596 | # We need to initialize the trackers we use, and also store our configuration. | ||
| 597 | # The trackers initialize automatically on the main process. | ||
| 598 | if accelerator.is_main_process: | ||
| 599 | accelerator.init_trackers("textual_inversion", config=vars(args)) | ||
| 600 | |||
| 601 | # Train! | ||
| 602 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | ||
| 603 | |||
| 604 | logger.info("***** Running training *****") | ||
| 605 | logger.info(f" Num Epochs = {args.num_train_epochs}") | ||
| 606 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") | ||
| 607 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") | ||
| 608 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | ||
| 609 | logger.info(f" Total optimization steps = {args.max_train_steps}") | ||
| 610 | # Only show the progress bar once on each machine. | ||
| 611 | |||
| 612 | global_step = 0 | ||
| 613 | min_val_loss = np.inf | ||
| 614 | |||
| 615 | imageToLatents = ImageToLatents(vae) | ||
| 616 | |||
| 617 | checkpointer.save_samples( | ||
| 618 | "validation", | ||
| 619 | 0, | ||
| 620 | text_encoder, | ||
| 621 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
| 622 | |||
| 623 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | ||
| 624 | progress_bar.set_description("Global steps") | ||
| 625 | |||
| 626 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) | ||
| 627 | local_progress_bar.set_description("Steps") | ||
| 628 | |||
| 629 | try: | ||
| 630 | for epoch in range(args.num_train_epochs): | ||
| 631 | local_progress_bar.reset() | ||
| 632 | |||
| 633 | text_encoder.train() | ||
| 634 | train_loss = 0.0 | ||
| 635 | |||
| 636 | for step, batch in enumerate(train_dataloader): | ||
| 637 | with accelerator.accumulate(text_encoder): | ||
| 638 | with accelerator.autocast(): | ||
| 639 | # Convert images to latent space | ||
| 640 | latents = imageToLatents(batch) | ||
| 641 | |||
| 642 | # Sample noise that we'll add to the latents | ||
| 643 | noise = torch.randn(latents.shape).to(latents.device) | ||
| 644 | bsz = latents.shape[0] | ||
| 645 | # Sample a random timestep for each image | ||
| 646 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, | ||
| 647 | (bsz,), device=latents.device).long() | ||
| 648 | |||
| 649 | # Add noise to the latents according to the noise magnitude at each timestep | ||
| 650 | # (this is the forward diffusion process) | ||
| 651 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 652 | |||
| 653 | # Get the text embedding for conditioning | ||
| 654 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | ||
| 655 | |||
| 656 | # Predict the noise residual | ||
| 657 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 658 | |||
| 659 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 660 | |||
| 661 | accelerator.backward(loss) | ||
| 662 | |||
| 663 | # Zero out the gradients for all token embeddings except the newly added | ||
| 664 | # embeddings for the concept, as we only want to optimize the concept embeddings | ||
| 665 | if accelerator.num_processes > 1: | ||
| 666 | grads = text_encoder.module.get_input_embeddings().weight.grad | ||
| 667 | else: | ||
| 668 | grads = text_encoder.get_input_embeddings().weight.grad | ||
| 669 | # Get the index for tokens that we want to zero the grads for | ||
| 670 | index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id | ||
| 671 | grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0) | ||
| 672 | |||
| 673 | optimizer.step() | ||
| 674 | if not accelerator.optimizer_step_was_skipped: | ||
| 675 | lr_scheduler.step() | ||
| 676 | optimizer.zero_grad() | ||
| 677 | |||
| 678 | loss = loss.detach().item() | ||
| 679 | train_loss += loss | ||
| 680 | |||
| 681 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
| 682 | if accelerator.sync_gradients: | ||
| 683 | progress_bar.update(1) | ||
| 684 | local_progress_bar.update(1) | ||
| 685 | global_step += 1 | ||
| 686 | |||
| 687 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | ||
| 688 | checkpointer.checkpoint(global_step + global_step_offset, text_encoder) | ||
| 689 | save_resume_file(basepath, args, { | ||
| 690 | "global_step": global_step + global_step_offset, | ||
| 691 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
| 692 | }) | ||
| 693 | checkpointer.save_samples( | ||
| 694 | "training", | ||
| 695 | global_step + global_step_offset, | ||
| 696 | text_encoder, | ||
| 697 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
| 698 | |||
| 699 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | ||
| 700 | local_progress_bar.set_postfix(**logs) | ||
| 701 | |||
| 702 | if global_step >= args.max_train_steps: | ||
| 703 | break | ||
| 704 | |||
| 705 | train_loss /= len(train_dataloader) | ||
| 706 | |||
| 707 | text_encoder.eval() | ||
| 708 | val_loss = 0.0 | ||
| 709 | |||
| 710 | for step, batch in enumerate(val_dataloader): | ||
| 711 | with torch.no_grad(), accelerator.autocast(): | ||
| 712 | latents = imageToLatents(batch) | ||
| 713 | |||
| 714 | noise = torch.randn(latents.shape).to(latents.device) | ||
| 715 | bsz = latents.shape[0] | ||
| 716 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, | ||
| 717 | (bsz,), device=latents.device).long() | ||
| 718 | |||
| 719 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 720 | |||
| 721 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | ||
| 722 | |||
| 723 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 724 | |||
| 725 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | ||
| 726 | |||
| 727 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 728 | |||
| 729 | loss = loss.detach().item() | ||
| 730 | val_loss += loss | ||
| 731 | |||
| 732 | if accelerator.sync_gradients: | ||
| 733 | progress_bar.update(1) | ||
| 734 | local_progress_bar.update(1) | ||
| 735 | |||
| 736 | logs = {"mode": "validation", "loss": loss} | ||
| 737 | local_progress_bar.set_postfix(**logs) | ||
| 738 | |||
| 739 | val_loss /= len(val_dataloader) | ||
| 740 | |||
| 741 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | ||
| 742 | |||
| 743 | if min_val_loss > val_loss: | ||
| 744 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | ||
| 745 | min_val_loss = val_loss | ||
| 746 | |||
| 747 | checkpointer.save_samples( | ||
| 748 | "validation", | ||
| 749 | global_step + global_step_offset, | ||
| 750 | text_encoder, | ||
| 751 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
| 752 | |||
| 753 | accelerator.wait_for_everyone() | ||
| 754 | |||
| 755 | # Create the pipeline using the trained modules and save it. | ||
| 756 | if accelerator.is_main_process: | ||
| 757 | print("Finished! Saving final checkpoint and resume state.") | ||
| 758 | checkpointer.checkpoint( | ||
| 759 | global_step + global_step_offset, | ||
| 760 | text_encoder, | ||
| 761 | path=f"{basepath}/learned_embeds.bin" | ||
| 762 | ) | ||
| 763 | |||
| 764 | save_resume_file(basepath, args, { | ||
| 765 | "global_step": global_step + global_step_offset, | ||
| 766 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
| 767 | }) | ||
| 768 | |||
| 769 | accelerator.end_training() | ||
| 770 | |||
| 771 | except KeyboardInterrupt: | ||
| 772 | if accelerator.is_main_process: | ||
| 773 | print("Interrupted, saving checkpoint and resume state...") | ||
| 774 | checkpointer.checkpoint(global_step + global_step_offset, text_encoder) | ||
| 775 | save_resume_file(basepath, args, { | ||
| 776 | "global_step": global_step + global_step_offset, | ||
| 777 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
| 778 | }) | ||
| 779 | accelerator.end_training() | ||
| 780 | quit() | ||
| 781 | |||
| 782 | |||
| 783 | if __name__ == "__main__": | ||
| 784 | main() | ||
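Training is launched by running main.py (directly or via accelerate launch) with the arguments defined in parse_args, e.g. --pretrained_model_name_or_path pointing at a diffusers-format checkpoint directory, --train_data_dir at the folder containing list.csv, and --placeholder_token/--initializer_token naming the new concept; outputs land under output_dir/<slugified token>/<timestamp>/, including learned_embeds.bin and checkpoints/last.bin. The sketch below (hypothetical paths) loads such a saved embedding back into a tokenizer and text encoder for inference, mirroring what the script itself does when --resume_checkpoint is given:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_dir = "pretrained/stable-diffusion"   # hypothetical diffusers-format checkpoint
learned = torch.load("text-inversion-model/my-token/2022-09-22T23-36-27/learned_embeds.bin")

tokenizer = CLIPTokenizer.from_pretrained(model_dir + "/tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_dir + "/text_encoder")

# The checkpoint is a dict {placeholder_token: embedding}, as written by Checkpointer.checkpoint().
placeholder_token, embedding = next(iter(learned.items()))
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))
token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
text_encoder.get_input_embeddings().weight.data[token_id] = embedding
# tokenizer and text_encoder can now be combined with the original vae/unet in a pipeline.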
diff --git a/pipelines/stable_diffusion/no_check.py b/pipelines/stable_diffusion/no_check.py
new file mode 100644
index 0000000..06c2f72
--- /dev/null
+++ b/pipelines/stable_diffusion/no_check.py
@@ -0,0 +1,13 @@
| 1 | from diffusers import ModelMixin | ||
| 2 | import torch | ||
| 3 | |||
| 4 | |||
| 5 | class NoCheck(ModelMixin): | ||
| 6 | """Can be used in place of safety checker. Use responsibly and at your own risk.""" | ||
| 7 | |||
| 8 | def __init__(self): | ||
| 9 | super().__init__() | ||
| 10 | self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3))) | ||
| 11 | |||
| 12 | def forward(self, images=None, **kwargs): | ||
| 13 | return images, [False] | ||
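NoCheck mimics the safety checker interface by returning the images unchanged together with a list of flags, and it registers a dummy parameter, presumably so that ModelMixin device/dtype handling has something to work with. main.py passes an instance as safety_checker when assembling its sampling pipeline; a minimal sketch of the same idea outside the training loop, assuming a local diffusers-format checkpoint and mirroring the ["sample"] access used in main.py:

from diffusers import StableDiffusionPipeline
from pipelines.stable_diffusion.no_check import NoCheck

# Hypothetical local path to a diffusers-format Stable Diffusion checkpoint.
pipe = StableDiffusionPipeline.from_pretrained(
    "pretrained/stable-diffusion",
    safety_checker=NoCheck(),  # replaces the default StableDiffusionSafetyChecker
).to("cuda")
result = pipe("a photo of <my-token>", num_inference_steps=50, guidance_scale=7.5)
result["sample"][0].save("sample.png")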
diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
new file mode 100644
index 0000000..ee7fc33
--- /dev/null
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -0,0 +1,690 @@
| 1 | # coding=utf-8 | ||
| 2 | # Copyright 2022 The HuggingFace Inc. team. | ||
| 3 | # | ||
| 4 | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| 5 | # you may not use this file except in compliance with the License. | ||
| 6 | # You may obtain a copy of the License at | ||
| 7 | # | ||
| 8 | # http://www.apache.org/licenses/LICENSE-2.0 | ||
| 9 | # | ||
| 10 | # Unless required by applicable law or agreed to in writing, software | ||
| 11 | # distributed under the License is distributed on an "AS IS" BASIS, | ||
| 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| 13 | # See the License for the specific language governing permissions and | ||
| 14 | # limitations under the License. | ||
| 15 | """ Conversion script for the LDM checkpoints. """ | ||
| 16 | |||
| 17 | import argparse | ||
| 18 | import os | ||
| 19 | |||
| 20 | import torch | ||
| 21 | |||
| 22 | |||
| 23 | try: | ||
| 24 | from omegaconf import OmegaConf | ||
| 25 | except ImportError: | ||
| 26 | raise ImportError( | ||
| 27 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." | ||
| 28 | ) | ||
| 29 | |||
| 30 | from diffusers import ( | ||
| 31 | AutoencoderKL, | ||
| 32 | DDIMScheduler, | ||
| 33 | LDMTextToImagePipeline, | ||
| 34 | LMSDiscreteScheduler, | ||
| 35 | PNDMScheduler, | ||
| 36 | StableDiffusionPipeline, | ||
| 37 | UNet2DConditionModel, | ||
| 38 | ) | ||
| 39 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel | ||
| 40 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
| 41 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer | ||
| 42 | |||
| 43 | |||
| 44 | def shave_segments(path, n_shave_prefix_segments=1): | ||
| 45 | """ | ||
| 46 | Removes segments. Positive values shave the first segments, negative shave the last segments. | ||
| 47 | """ | ||
| 48 | if n_shave_prefix_segments >= 0: | ||
| 49 | return ".".join(path.split(".")[n_shave_prefix_segments:]) | ||
| 50 | else: | ||
| 51 | return ".".join(path.split(".")[:n_shave_prefix_segments]) | ||
| 52 | |||
| 53 | |||
| 54 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
| 55 | """ | ||
| 56 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
| 57 | """ | ||
| 58 | mapping = [] | ||
| 59 | for old_item in old_list: | ||
| 60 | new_item = old_item.replace("in_layers.0", "norm1") | ||
| 61 | new_item = new_item.replace("in_layers.2", "conv1") | ||
| 62 | |||
| 63 | new_item = new_item.replace("out_layers.0", "norm2") | ||
| 64 | new_item = new_item.replace("out_layers.3", "conv2") | ||
| 65 | |||
| 66 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") | ||
| 67 | new_item = new_item.replace("skip_connection", "conv_shortcut") | ||
| 68 | |||
| 69 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 70 | |||
| 71 | mapping.append({"old": old_item, "new": new_item}) | ||
| 72 | |||
| 73 | return mapping | ||
| 74 | |||
| 75 | |||
| 76 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): | ||
| 77 | """ | ||
| 78 | Updates paths inside resnets to the new naming scheme (local renaming) | ||
| 79 | """ | ||
| 80 | mapping = [] | ||
| 81 | for old_item in old_list: | ||
| 82 | new_item = old_item | ||
| 83 | |||
| 84 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") | ||
| 85 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 86 | |||
| 87 | mapping.append({"old": old_item, "new": new_item}) | ||
| 88 | |||
| 89 | return mapping | ||
| 90 | |||
| 91 | |||
| 92 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): | ||
| 93 | """ | ||
| 94 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
| 95 | """ | ||
| 96 | mapping = [] | ||
| 97 | for old_item in old_list: | ||
| 98 | new_item = old_item | ||
| 99 | |||
| 100 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') | ||
| 101 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') | ||
| 102 | |||
| 103 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') | ||
| 104 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') | ||
| 105 | |||
| 106 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 107 | |||
| 108 | mapping.append({"old": old_item, "new": new_item}) | ||
| 109 | |||
| 110 | return mapping | ||
| 111 | |||
| 112 | |||
| 113 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): | ||
| 114 | """ | ||
| 115 | Updates paths inside attentions to the new naming scheme (local renaming) | ||
| 116 | """ | ||
| 117 | mapping = [] | ||
| 118 | for old_item in old_list: | ||
| 119 | new_item = old_item | ||
| 120 | |||
| 121 | new_item = new_item.replace("norm.weight", "group_norm.weight") | ||
| 122 | new_item = new_item.replace("norm.bias", "group_norm.bias") | ||
| 123 | |||
| 124 | new_item = new_item.replace("q.weight", "query.weight") | ||
| 125 | new_item = new_item.replace("q.bias", "query.bias") | ||
| 126 | |||
| 127 | new_item = new_item.replace("k.weight", "key.weight") | ||
| 128 | new_item = new_item.replace("k.bias", "key.bias") | ||
| 129 | |||
| 130 | new_item = new_item.replace("v.weight", "value.weight") | ||
| 131 | new_item = new_item.replace("v.bias", "value.bias") | ||
| 132 | |||
| 133 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") | ||
| 134 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") | ||
| 135 | |||
| 136 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) | ||
| 137 | |||
| 138 | mapping.append({"old": old_item, "new": new_item}) | ||
| 139 | |||
| 140 | return mapping | ||
| 141 | |||
| 142 | |||
| 143 | def assign_to_checkpoint( | ||
| 144 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None | ||
| 145 | ): | ||
| 146 | """ | ||
| 147 | This does the final conversion step: take locally converted weights and apply a global renaming | ||
| 148 | to them. It splits attention layers, and takes into account additional replacements | ||
| 149 | that may arise. | ||
| 150 | |||
| 151 | Assigns the weights to the new checkpoint. | ||
| 152 | """ | ||
| 153 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." | ||
| 154 | |||
| 155 | # Splits the attention layers into three variables. | ||
| 156 | if attention_paths_to_split is not None: | ||
| 157 | for path, path_map in attention_paths_to_split.items(): | ||
| 158 | old_tensor = old_checkpoint[path] | ||
| 159 | channels = old_tensor.shape[0] // 3 | ||
| 160 | |||
| 161 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) | ||
| 162 | |||
| 163 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 | ||
| 164 | |||
| 165 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) | ||
| 166 | query, key, value = old_tensor.split(channels // num_heads, dim=1) | ||
| 167 | |||
| 168 | checkpoint[path_map["query"]] = query.reshape(target_shape) | ||
| 169 | checkpoint[path_map["key"]] = key.reshape(target_shape) | ||
| 170 | checkpoint[path_map["value"]] = value.reshape(target_shape) | ||
| 171 | |||
| 172 | for path in paths: | ||
| 173 | new_path = path["new"] | ||
| 174 | |||
| 175 | # These have already been assigned | ||
| 176 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: | ||
| 177 | continue | ||
| 178 | |||
| 179 | # Global renaming happens here | ||
| 180 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") | ||
| 181 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") | ||
| 182 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") | ||
| 183 | |||
| 184 | if additional_replacements is not None: | ||
| 185 | for replacement in additional_replacements: | ||
| 186 | new_path = new_path.replace(replacement["old"], replacement["new"]) | ||
| 187 | |||
| 188 | # proj_attn.weight has to be converted from conv 1D to linear | ||
| 189 | if "proj_attn.weight" in new_path: | ||
| 190 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] | ||
| 191 | else: | ||
| 192 | checkpoint[new_path] = old_checkpoint[path["old"]] | ||
| 193 | |||
| 194 | |||
| 195 | def conv_attn_to_linear(checkpoint): | ||
| 196 | keys = list(checkpoint.keys()) | ||
| 197 | attn_keys = ["query.weight", "key.weight", "value.weight"] | ||
| 198 | for key in keys: | ||
| 199 | if ".".join(key.split(".")[-2:]) in attn_keys: | ||
| 200 | if checkpoint[key].ndim > 2: | ||
| 201 | checkpoint[key] = checkpoint[key][:, :, 0, 0] | ||
| 202 | elif "proj_attn.weight" in key: | ||
| 203 | if checkpoint[key].ndim > 2: | ||
| 204 | checkpoint[key] = checkpoint[key][:, :, 0] | ||
| 205 | |||
| 206 | |||
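A small illustration of the weight squeeze `conv_attn_to_linear` performs; the key name and shape below are made up for the example:

```python
import torch

# A 1x1 conv projection weight of shape (out, in, 1, 1) becomes a plain linear weight (out, in).
ckpt = {"encoder.mid_block.attentions.0.query.weight": torch.randn(512, 512, 1, 1)}
conv_attn_to_linear(ckpt)
assert ckpt["encoder.mid_block.attentions.0.query.weight"].shape == (512, 512)
```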
| 207 | def create_unet_diffusers_config(original_config): | ||
| 208 | """ | ||
| 209 | Creates a UNet config for diffusers based on the config of the LDM model. | ||
| 210 | """ | ||
| 211 | unet_params = original_config.model.params.unet_config.params | ||
| 212 | |||
| 213 | block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] | ||
| 214 | |||
| 215 | down_block_types = [] | ||
| 216 | resolution = 1 | ||
| 217 | for i in range(len(block_out_channels)): | ||
| 218 | block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" | ||
| 219 | down_block_types.append(block_type) | ||
| 220 | if i != len(block_out_channels) - 1: | ||
| 221 | resolution *= 2 | ||
| 222 | |||
| 223 | up_block_types = [] | ||
| 224 | for i in range(len(block_out_channels)): | ||
| 225 | block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" | ||
| 226 | up_block_types.append(block_type) | ||
| 227 | resolution //= 2 | ||
| 228 | |||
| 229 | config = dict( | ||
| 230 | sample_size=unet_params.image_size, | ||
| 231 | in_channels=unet_params.in_channels, | ||
| 232 | out_channels=unet_params.out_channels, | ||
| 233 | down_block_types=tuple(down_block_types), | ||
| 234 | up_block_types=tuple(up_block_types), | ||
| 235 | block_out_channels=tuple(block_out_channels), | ||
| 236 | layers_per_block=unet_params.num_res_blocks, | ||
| 237 | cross_attention_dim=unet_params.context_dim, | ||
| 238 | attention_head_dim=unet_params.num_heads, | ||
| 239 | ) | ||
| 240 | |||
| 241 | return config | ||
| 242 | |||
| 243 | |||
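As a rough worked example (assuming the stock v1-inference.yaml values: model_channels 320, channel_mult [1, 2, 4, 4], attention_resolutions [4, 2, 1], num_res_blocks 2), the loops above resolve to something like:

```python
# Sketch only; the exact values depend on the YAML that is actually loaded.
original_config = OmegaConf.load("./v1-inference.yaml")
unet_config = create_unet_diffusers_config(original_config)
# block_out_channels -> (320, 640, 1280, 1280)
# down_block_types   -> ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
# up_block_types     -> ("UpBlock2D",) + ("CrossAttnUpBlock2D",) * 3
```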
| 244 | def create_vae_diffusers_config(original_config): | ||
| 245 | """ | ||
| 246 | Creates a VAE config for diffusers based on the config of the LDM model. | ||
| 247 | """ | ||
| 248 | vae_params = original_config.model.params.first_stage_config.params.ddconfig | ||
| 249 | _ = original_config.model.params.first_stage_config.params.embed_dim | ||
| 250 | |||
| 251 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] | ||
| 252 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) | ||
| 253 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) | ||
| 254 | |||
| 255 | config = dict( | ||
| 256 | sample_size=vae_params.resolution, | ||
| 257 | in_channels=vae_params.in_channels, | ||
| 258 | out_channels=vae_params.out_ch, | ||
| 259 | down_block_types=tuple(down_block_types), | ||
| 260 | up_block_types=tuple(up_block_types), | ||
| 261 | block_out_channels=tuple(block_out_channels), | ||
| 262 | latent_channels=vae_params.z_channels, | ||
| 263 | layers_per_block=vae_params.num_res_blocks, | ||
| 264 | ) | ||
| 265 | return config | ||
| 266 | |||
| 267 | |||
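Similarly, for an assumed v1 autoencoder ddconfig (ch 128, ch_mult [1, 2, 4, 4], resolution 256, z_channels 4), this would come out roughly as:

```python
# Sketch only, under the assumed ddconfig values above.
vae_config = create_vae_diffusers_config(original_config)
# block_out_channels -> (128, 256, 512, 512)
# down_block_types   -> ("DownEncoderBlock2D",) * 4
# up_block_types     -> ("UpDecoderBlock2D",) * 4
# latent_channels    -> 4
```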
| 268 | def create_diffusers_scheduler(original_config): | ||
| 269 | scheduler = DDIMScheduler( | ||
| 270 | num_train_timesteps=original_config.model.params.timesteps, | ||
| 271 | beta_start=original_config.model.params.linear_start, | ||
| 272 | beta_end=original_config.model.params.linear_end, | ||
| 273 | beta_schedule="scaled_linear", | ||
| 274 | ) | ||
| 275 | return scheduler | ||
| 276 | |||
| 277 | |||
| 278 | def create_ldm_bert_config(original_config): | ||
| 279 | bert_params = original_config.model.params.cond_stage_config.params | ||
| 280 | config = LDMBertConfig( | ||
| 281 | d_model=bert_params.n_embed, | ||
| 282 | encoder_layers=bert_params.n_layer, | ||
| 283 | encoder_ffn_dim=bert_params.n_embed * 4, | ||
| 284 | ) | ||
| 285 | return config | ||
| 286 | |||
| 287 | |||
| 288 | def convert_ldm_unet_checkpoint(checkpoint, config): | ||
| 289 | """ | ||
| 290 | Takes a state dict and a config, and returns a converted checkpoint. | ||
| 291 | """ | ||
| 292 | |||
| 293 | # extract state_dict for UNet | ||
| 294 | unet_state_dict = {} | ||
| 295 | unet_key = "model.diffusion_model." | ||
| 296 | keys = list(checkpoint.keys()) | ||
| 297 | for key in keys: | ||
| 298 | if key.startswith(unet_key): | ||
| 299 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) | ||
| 300 | |||
| 301 | new_checkpoint = {} | ||
| 302 | |||
| 303 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] | ||
| 304 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] | ||
| 305 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] | ||
| 306 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] | ||
| 307 | |||
| 308 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] | ||
| 309 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] | ||
| 310 | |||
| 311 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] | ||
| 312 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] | ||
| 313 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] | ||
| 314 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] | ||
| 315 | |||
| 316 | # Retrieves the keys for the input blocks only | ||
| 317 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) | ||
| 318 | input_blocks = { | ||
| 319 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] | ||
| 320 | for layer_id in range(num_input_blocks) | ||
| 321 | } | ||
| 322 | |||
| 323 | # Retrieves the keys for the middle blocks only | ||
| 324 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) | ||
| 325 | middle_blocks = { | ||
| 326 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] | ||
| 327 | for layer_id in range(num_middle_blocks) | ||
| 328 | } | ||
| 329 | |||
| 330 | # Retrieves the keys for the output blocks only | ||
| 331 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) | ||
| 332 | output_blocks = { | ||
| 333 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] | ||
| 334 | for layer_id in range(num_output_blocks) | ||
| 335 | } | ||
| 336 | |||
| 337 | for i in range(1, num_input_blocks): | ||
| 338 | block_id = (i - 1) // (config["layers_per_block"] + 1) | ||
| 339 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) | ||
| 340 | |||
| 341 | resnets = [ | ||
| 342 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key | ||
| 343 | ] | ||
| 344 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] | ||
| 345 | |||
| 346 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: | ||
| 347 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( | ||
| 348 | f"input_blocks.{i}.0.op.weight" | ||
| 349 | ) | ||
| 350 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( | ||
| 351 | f"input_blocks.{i}.0.op.bias" | ||
| 352 | ) | ||
| 353 | |||
| 354 | paths = renew_resnet_paths(resnets) | ||
| 355 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
| 356 | assign_to_checkpoint( | ||
| 357 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 358 | ) | ||
| 359 | |||
| 360 | if len(attentions): | ||
| 361 | paths = renew_attention_paths(attentions) | ||
| 362 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} | ||
| 363 | assign_to_checkpoint( | ||
| 364 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 365 | ) | ||
| 366 | |||
| 367 | resnet_0 = middle_blocks[0] | ||
| 368 | attentions = middle_blocks[1] | ||
| 369 | resnet_1 = middle_blocks[2] | ||
| 370 | |||
| 371 | resnet_0_paths = renew_resnet_paths(resnet_0) | ||
| 372 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) | ||
| 373 | |||
| 374 | resnet_1_paths = renew_resnet_paths(resnet_1) | ||
| 375 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) | ||
| 376 | |||
| 377 | attentions_paths = renew_attention_paths(attentions) | ||
| 378 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} | ||
| 379 | assign_to_checkpoint( | ||
| 380 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 381 | ) | ||
| 382 | |||
| 383 | for i in range(num_output_blocks): | ||
| 384 | block_id = i // (config["layers_per_block"] + 1) | ||
| 385 | layer_in_block_id = i % (config["layers_per_block"] + 1) | ||
| 386 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] | ||
| 387 | output_block_list = {} | ||
| 388 | |||
| 389 | for layer in output_block_layers: | ||
| 390 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) | ||
| 391 | if layer_id in output_block_list: | ||
| 392 | output_block_list[layer_id].append(layer_name) | ||
| 393 | else: | ||
| 394 | output_block_list[layer_id] = [layer_name] | ||
| 395 | |||
| 396 | if len(output_block_list) > 1: | ||
| 397 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] | ||
| 398 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] | ||
| 399 | |||
| 400 | resnet_0_paths = renew_resnet_paths(resnets) | ||
| 401 | paths = renew_resnet_paths(resnets) | ||
| 402 | |||
| 403 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} | ||
| 404 | assign_to_checkpoint( | ||
| 405 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 406 | ) | ||
| 407 | |||
| 408 | if ["conv.weight", "conv.bias"] in output_block_list.values(): | ||
| 409 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) | ||
| 410 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ | ||
| 411 | f"output_blocks.{i}.{index}.conv.weight" | ||
| 412 | ] | ||
| 413 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ | ||
| 414 | f"output_blocks.{i}.{index}.conv.bias" | ||
| 415 | ] | ||
| 416 | |||
| 417 | # If the "attention" slot actually held the upsampler conv handled above, clear it. | ||
| 418 | if len(attentions) == 2: | ||
| 419 | attentions = [] | ||
| 420 | |||
| 421 | if len(attentions): | ||
| 422 | paths = renew_attention_paths(attentions) | ||
| 423 | meta_path = { | ||
| 424 | "old": f"output_blocks.{i}.1", | ||
| 425 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", | ||
| 426 | } | ||
| 427 | assign_to_checkpoint( | ||
| 428 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config | ||
| 429 | ) | ||
| 430 | else: | ||
| 431 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) | ||
| 432 | for path in resnet_0_paths: | ||
| 433 | old_path = ".".join(["output_blocks", str(i), path["old"]]) | ||
| 434 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) | ||
| 435 | |||
| 436 | new_checkpoint[new_path] = unet_state_dict[old_path] | ||
| 437 | |||
| 438 | return new_checkpoint | ||
| 439 | |||
| 440 | |||
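A hedged sanity check of the UNet conversion; the checkpoint filename is a placeholder, and the snippet assumes this script's own imports and helpers are in scope:

```python
# Placeholder paths; mirrors what __main__ below does, with an explicit strictness check.
state_dict = torch.load("sd-v1-4.ckpt", map_location="cpu")["state_dict"]
unet_config = create_unet_diffusers_config(OmegaConf.load("./v1-inference.yaml"))
converted = convert_ldm_unet_checkpoint(state_dict, unet_config)

unet = UNet2DConditionModel(**unet_config)
missing, unexpected = unet.load_state_dict(converted, strict=False)
assert not missing and not unexpected, (missing, unexpected)
```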
| 441 | def convert_ldm_vae_checkpoint(checkpoint, config): | ||
| 442 | # extract state dict for VAE | ||
| 443 | vae_state_dict = {} | ||
| 444 | vae_key = "first_stage_model." | ||
| 445 | keys = list(checkpoint.keys()) | ||
| 446 | for key in keys: | ||
| 447 | if key.startswith(vae_key): | ||
| 448 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) | ||
| 449 | |||
| 450 | new_checkpoint = {} | ||
| 451 | |||
| 452 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] | ||
| 453 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] | ||
| 454 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] | ||
| 455 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] | ||
| 456 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] | ||
| 457 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] | ||
| 458 | |||
| 459 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] | ||
| 460 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] | ||
| 461 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] | ||
| 462 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] | ||
| 463 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] | ||
| 464 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] | ||
| 465 | |||
| 466 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] | ||
| 467 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] | ||
| 468 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] | ||
| 469 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] | ||
| 470 | |||
| 471 | # Retrieves the keys for the encoder down blocks only | ||
| 472 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) | ||
| 473 | down_blocks = { | ||
| 474 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) | ||
| 475 | } | ||
| 476 | |||
| 477 | # Retrieves the keys for the decoder up blocks only | ||
| 478 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) | ||
| 479 | up_blocks = { | ||
| 480 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) | ||
| 481 | } | ||
| 482 | |||
| 483 | for i in range(num_down_blocks): | ||
| 484 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] | ||
| 485 | |||
| 486 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: | ||
| 487 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( | ||
| 488 | f"encoder.down.{i}.downsample.conv.weight" | ||
| 489 | ) | ||
| 490 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( | ||
| 491 | f"encoder.down.{i}.downsample.conv.bias" | ||
| 492 | ) | ||
| 493 | |||
| 494 | paths = renew_vae_resnet_paths(resnets) | ||
| 495 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} | ||
| 496 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 497 | |||
| 498 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] | ||
| 499 | num_mid_res_blocks = 2 | ||
| 500 | for i in range(1, num_mid_res_blocks + 1): | ||
| 501 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] | ||
| 502 | |||
| 503 | paths = renew_vae_resnet_paths(resnets) | ||
| 504 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
| 505 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 506 | |||
| 507 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] | ||
| 508 | paths = renew_vae_attention_paths(mid_attentions) | ||
| 509 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
| 510 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 511 | conv_attn_to_linear(new_checkpoint) | ||
| 512 | |||
| 513 | for i in range(num_up_blocks): | ||
| 514 | block_id = num_up_blocks - 1 - i | ||
| 515 | resnets = [ | ||
| 516 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key | ||
| 517 | ] | ||
| 518 | |||
| 519 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: | ||
| 520 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ | ||
| 521 | f"decoder.up.{block_id}.upsample.conv.weight" | ||
| 522 | ] | ||
| 523 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ | ||
| 524 | f"decoder.up.{block_id}.upsample.conv.bias" | ||
| 525 | ] | ||
| 526 | |||
| 527 | paths = renew_vae_resnet_paths(resnets) | ||
| 528 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} | ||
| 529 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 530 | |||
| 531 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] | ||
| 532 | num_mid_res_blocks = 2 | ||
| 533 | for i in range(1, num_mid_res_blocks + 1): | ||
| 534 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] | ||
| 535 | |||
| 536 | paths = renew_vae_resnet_paths(resnets) | ||
| 537 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} | ||
| 538 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 539 | |||
| 540 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] | ||
| 541 | paths = renew_vae_attention_paths(mid_attentions) | ||
| 542 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} | ||
| 543 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) | ||
| 544 | conv_attn_to_linear(new_checkpoint) | ||
| 545 | return new_checkpoint | ||
| 546 | |||
| 547 | |||
| 548 | def convert_ldm_bert_checkpoint(checkpoint, config): | ||
| 549 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer): | ||
| 550 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight | ||
| 551 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight | ||
| 552 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight | ||
| 553 | |||
| 554 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight | ||
| 555 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias | ||
| 556 | |||
| 557 | def _copy_linear(hf_linear, pt_linear): | ||
| 558 | hf_linear.weight = pt_linear.weight | ||
| 559 | hf_linear.bias = pt_linear.bias | ||
| 560 | |||
| 561 | def _copy_layer(hf_layer, pt_layer): | ||
| 562 | # copy layer norms | ||
| 563 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) | ||
| 564 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) | ||
| 565 | |||
| 566 | # copy attn | ||
| 567 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) | ||
| 568 | |||
| 569 | # copy MLP | ||
| 570 | pt_mlp = pt_layer[1][1] | ||
| 571 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) | ||
| 572 | _copy_linear(hf_layer.fc2, pt_mlp.net[2]) | ||
| 573 | |||
| 574 | def _copy_layers(hf_layers, pt_layers): | ||
| 575 | for i, hf_layer in enumerate(hf_layers): | ||
| 576 | # each HF encoder layer corresponds to two consecutive (norm, sublayer) pairs | ||
| 577 | # in the PT attn_layers list: one for attention, one for the feed-forward MLP | ||
| 578 | pt_layer = pt_layers[2 * i : 2 * i + 2] | ||
| 579 | _copy_layer(hf_layer, pt_layer) | ||
| 580 | |||
| 581 | hf_model = LDMBertModel(config).eval() | ||
| 582 | |||
| 583 | # copy embeds | ||
| 584 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight | ||
| 585 | hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight | ||
| 586 | |||
| 587 | # copy layer norm | ||
| 588 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) | ||
| 589 | |||
| 590 | # copy hidden layers | ||
| 591 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) | ||
| 592 | |||
| 593 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) | ||
| 594 | |||
| 595 | return hf_model | ||
| 596 | |||
| 597 | |||
| 598 | if __name__ == "__main__": | ||
| 599 | parser = argparse.ArgumentParser() | ||
| 600 | |||
| 601 | parser.add_argument( | ||
| 602 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." | ||
| 603 | ) | ||
| 604 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml | ||
| 605 | parser.add_argument( | ||
| 606 | "--original_config_file", | ||
| 607 | default=None, | ||
| 608 | type=str, | ||
| 609 | help="The YAML config file corresponding to the original architecture.", | ||
| 610 | ) | ||
| 611 | parser.add_argument( | ||
| 612 | "--scheduler_type", | ||
| 613 | default="pndm", | ||
| 614 | type=str, | ||
| 615 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", | ||
| 616 | ) | ||
| 617 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") | ||
| 618 | |||
| 619 | args = parser.parse_args() | ||
| 620 | |||
| 621 | if args.original_config_file is None: | ||
| 622 | os.system( | ||
| 623 | "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" | ||
| 624 | ) | ||
| 625 | args.original_config_file = "./v1-inference.yaml" | ||
| 626 | |||
| 627 | original_config = OmegaConf.load(args.original_config_file) | ||
| 628 | checkpoint = torch.load(args.checkpoint_path, map_location="cpu")["state_dict"] | ||
| 629 | |||
| 630 | num_train_timesteps = original_config.model.params.timesteps | ||
| 631 | beta_start = original_config.model.params.linear_start | ||
| 632 | beta_end = original_config.model.params.linear_end | ||
| 633 | if args.scheduler_type == "pndm": | ||
| 634 | scheduler = PNDMScheduler( | ||
| 635 | beta_end=beta_end, | ||
| 636 | beta_schedule="scaled_linear", | ||
| 637 | beta_start=beta_start, | ||
| 638 | num_train_timesteps=num_train_timesteps, | ||
| 639 | skip_prk_steps=True, | ||
| 640 | ) | ||
| 641 | elif args.scheduler_type == "lms": | ||
| 642 | scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") | ||
| 643 | elif args.scheduler_type == "ddim": | ||
| 644 | scheduler = DDIMScheduler( | ||
| 645 | beta_start=beta_start, | ||
| 646 | beta_end=beta_end, | ||
| 647 | beta_schedule="scaled_linear", | ||
| 648 | clip_sample=False, | ||
| 649 | set_alpha_to_one=False, | ||
| 650 | ) | ||
| 651 | else: | ||
| 652 | raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") | ||
| 653 | |||
| 654 | # Convert the UNet2DConditionModel. | ||
| 655 | unet_config = create_unet_diffusers_config(original_config) | ||
| 656 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) | ||
| 657 | |||
| 658 | unet = UNet2DConditionModel(**unet_config) | ||
| 659 | unet.load_state_dict(converted_unet_checkpoint) | ||
| 660 | |||
| 661 | # Convert the VAE model. | ||
| 662 | vae_config = create_vae_diffusers_config(original_config) | ||
| 663 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) | ||
| 664 | |||
| 665 | vae = AutoencoderKL(**vae_config) | ||
| 666 | vae.load_state_dict(converted_vae_checkpoint) | ||
| 667 | |||
| 668 | # Convert the text model. | ||
| 669 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] | ||
| 670 | if text_model_type == "FrozenCLIPEmbedder": | ||
| 671 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") | ||
| 672 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") | ||
| 673 | safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
| 674 | feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") | ||
| 675 | pipe = StableDiffusionPipeline( | ||
| 676 | vae=vae, | ||
| 677 | text_encoder=text_model, | ||
| 678 | tokenizer=tokenizer, | ||
| 679 | unet=unet, | ||
| 680 | scheduler=scheduler, | ||
| 681 | safety_checker=safety_checker, | ||
| 682 | feature_extractor=feature_extractor, | ||
| 683 | ) | ||
| 684 | else: | ||
| 685 | text_config = create_ldm_bert_config(original_config) | ||
| 686 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) | ||
| 687 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") | ||
| 688 | pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) | ||
| 689 | |||
| 690 | pipe.save_pretrained(args.dump_path) | ||
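Once saved, the converted weights can be loaded back as a regular diffusers pipeline for a quick smoke test; the path below is a placeholder for whatever was passed as --dump_path:

```python
from diffusers import StableDiffusionPipeline

# "./converted-model" is a placeholder for the --dump_path used above.
pipe = StableDiffusionPipeline.from_pretrained("./converted-model")
```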
diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..75158a0 --- /dev/null +++ b/setup.py | |||
| @@ -0,0 +1,13 @@ | |||
| 1 | from setuptools import setup, find_packages | ||
| 2 | |||
| 3 | setup( | ||
| 4 | name='textual-inversion-diff', | ||
| 5 | version='0.0.1', | ||
| 6 | description='', | ||
| 7 | packages=find_packages(), | ||
| 8 | install_requires=[ | ||
| 9 | 'torch', | ||
| 10 | 'numpy', | ||
| 11 | 'tqdm', | ||
| 12 | ], | ||
| 13 | ) | ||
