|  |  |  |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-10-13 21:11:53 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2022-10-13 21:11:53 +0200 |
| commit | 515f0f1fdc9a76bf63bd746c291dcfec7fc747fb (patch) | |
| tree | cf4bbf1cae822cf7c7f388b6918154032def0376 | |
| parent | Added TI+Dreambooth training (diff) | |
Added support for Aesthetic Gradients
| mode | file | changes |
|---|---|---|
| -rw-r--r-- | .gitignore | 4 |
| -rw-r--r-- | aesthetic_gradient.py | 137 |
| -rw-r--r-- | data/csv.py | 2 |
| -rw-r--r-- | dreambooth.py | 10 |
| -rw-r--r-- | dreambooth_plus.py | 16 |
| -rw-r--r-- | infer.py | 75 |
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 52 |
| -rw-r--r-- | textual_inversion.py | 10 |

8 files changed, 245 insertions, 61 deletions
diff --git a/.gitignore b/.gitignore
--- a/.gitignore
+++ b/.gitignore
@@ -161,5 +161,7 @@ cython_debug/
 
 output/
 conf/
-embeddings/
+embeddings_ti/
+embeddings_ag/
 v1-inference.yaml*
+*.old
diff --git a/aesthetic_gradient.py b/aesthetic_gradient.py
new file mode 100644
index 0000000..5386d0f
--- /dev/null
+++ b/aesthetic_gradient.py
@@ -0,0 +1,137 @@
+import argparse
+import datetime
+import logging
+import json
+from pathlib import Path
+
+import torch
+import torch.utils.checkpoint
+from torchvision import transforms
+import pandas as pd
+
+from accelerate.logging import get_logger
+from PIL import Image
+from tqdm import tqdm
+from transformers import CLIPModel
+from slugify import slugify
+
+logger = get_logger(__name__)
+
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="A directory."
+    )
+    parser.add_argument(
+        "--token",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=224,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/aesthetic-gradient",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
+
+    if args.token is None:
+        raise ValueError("You must specify --token")
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir)
+    basepath.mkdir(parents=True, exist_ok=True)
+    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")
+
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
+    data_file = Path(args.train_data_file)
+    if not data_file.is_file():
+        raise ValueError("data_file must be a file")
+    data_root = data_file.parent
+    metadata = pd.read_csv(data_file)
+    image_paths = [
+        data_root.joinpath(item.image)
+        for item in metadata.itertuples()
+        if "skip" not in item or item.skip != "x"
+    ]
+
+    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+
+    image_transforms = transforms.Compose(
+        [
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
+            transforms.RandomCrop(args.resolution),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
+
+    with torch.no_grad():
+        embs = []
+        for path in tqdm(image_paths):
+            image = Image.open(path)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+            image = image_transforms(image).unsqueeze(0)
+            emb = model.get_image_features(image)
+            print(f">>>> {emb.shape}")
+            embs.append(emb)
+
+        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
+
+        print(embs.shape)
+
+        torch.save(embs, target)
+
+
+if __name__ == "__main__":
+    main()
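
The new script reduces a set of reference images to a single averaged CLIP image embedding and writes it as a `.pt` file named from the slugified `--token` plus a timestamp. A minimal sketch of inspecting that output (the file name below is a hypothetical example of that naming scheme):

```python
# Hypothetical output path written by aesthetic_gradient.py:
# output/aesthetic-gradient/<slug(--token)>-<timestamp>.pt
import torch

emb = torch.load("output/aesthetic-gradient/my-style-2022-10-13T21-11-53.pt", map_location="cpu")

# For openai/clip-vit-large-patch14, get_image_features() returns 768-dimensional
# vectors, so the averaged embedding should have shape [1, 768].
print(emb.shape)
```
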
diff --git a/data/csv.py b/data/csv.py
index 253ce9e..aad970c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -23,7 +23,7 @@ class CSVDataModule(pl.LightningDataModule):
         tokenizer,
         instance_identifier,
         class_identifier=None,
-        class_subdir="db_cls",
+        class_subdir="cls",
         num_class_images=100,
         size=512,
         repeats=100,
diff --git a/dreambooth.py b/dreambooth.py
index 699313e..072142e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -216,12 +216,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=-1,
-        help="For distributed training: local_rank"
-    )
-    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=100,
@@ -287,10 +281,6 @@ def parse_args():
             args = parser.parse_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))
 
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
     if args.train_data_file is None:
         raise ValueError("You must specify --train_data_file")
 
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 9e482b3..7996bc2 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=3000,
+        default=1600,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -129,13 +129,13 @@
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=1e-5,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-4,
+        default=5e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -222,12 +222,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=-1,
-        help="For distributed training: local_rank"
-    )
-    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=100,
@@ -293,10 +287,6 @@ def parse_args():
             args = parser.parse_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))
 
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
     if args.train_data_file is None:
         raise ValueError("You must specify --train_data_file")
 
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -23,7 +23,8 @@ default_args = {
     "model": None,
     "scheduler": "euler_a",
     "precision": "fp32",
-    "embeddings_dir": "embeddings",
+    "ti_embeddings_dir": "embeddings_ti",
+    "ag_embeddings_dir": "embeddings_ag",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -73,7 +74,11 @@ def create_args_parser():
         choices=["fp32", "fp16", "bf16"],
     )
     parser.add_argument(
-        "--embeddings_dir",
+        "--ti_embeddings_dir",
+        type=str,
+    )
+    parser.add_argument(
+        "--ag_embeddings_dir",
         type=str,
     )
     parser.add_argument(
@@ -167,42 +172,63 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(tokenizer, text_encoder, embeddings_dir):
+def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
+    print(f"Loading Textual Inversion embeddings")
+
     embeddings_dir = Path(embeddings_dir)
     embeddings_dir.mkdir(parents=True, exist_ok=True)
 
     for file in embeddings_dir.iterdir():
-        placeholder_token = file.stem
-
-        num_added_tokens = tokenizer.add_tokens(placeholder_token)
-        if num_added_tokens == 0:
-            raise ValueError(
-                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
-                " `placeholder_token` that is not already in the tokenizer."
-            )
+        if file.is_file():
+            placeholder_token = file.stem
+
+            num_added_tokens = tokenizer.add_tokens(placeholder_token)
+            if num_added_tokens == 0:
+                raise ValueError(
+                    f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+                    " `placeholder_token` that is not already in the tokenizer."
+                )
 
     text_encoder.resize_token_embeddings(len(tokenizer))
 
     token_embeds = text_encoder.get_input_embeddings().weight.data
 
     for file in embeddings_dir.iterdir():
-        placeholder_token = file.stem
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
-
-        data = torch.load(file, map_location="cpu")
-
-        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-        emb = next(iter(data.values()))
-        if len(emb.shape) == 1:
-            emb = emb.unsqueeze(0)
-
-        token_embeds[placeholder_token_id] = emb
-
-        print(f"Loaded embedding: {placeholder_token}")
+        if file.is_file():
+            placeholder_token = file.stem
+            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
+
+            data = torch.load(file, map_location="cpu")
+
+            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+            emb = next(iter(data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+
+            token_embeds[placeholder_token_id] = emb
+
+            print(f"Loaded {placeholder_token}")
+
+
+def load_embeddings_ag(pipeline, embeddings_dir):
+    print(f"Loading Aesthetic Gradient embeddings")
+
+    embeddings_dir = Path(embeddings_dir)
+    embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+    for file in embeddings_dir.iterdir():
+        if file.is_file():
+            placeholder_token = file.stem
+
+            data = torch.load(file, map_location="cpu")
+
+            pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
+
+            print(f"Loaded {placeholder_token}")
 
 
-def create_pipeline(model, scheduler, embeddings_dir, dtype):
+def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -210,7 +236,7 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
 
-    load_embeddings(tokenizer, text_encoder, embeddings_dir)
+    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
 
     if scheduler == "plms":
         scheduler = PNDMScheduler(
@@ -236,10 +262,13 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
+    pipeline.aesthetic_gradient_iters = 30
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
 
+    load_embeddings_ag(pipeline, ag_embeddings_dir)
+
     return pipeline
 
 
@@ -259,7 +288,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
-    with torch.autocast("cuda"), torch.inference_mode():
+    with torch.autocast("cuda"):
        for i in range(args.batch_num):
            pipeline.set_progress_bar_config(
                desc=f"Batch {i + 1} of {args.batch_num}",
@@ -337,7 +366,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
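
As a usage sketch (directory names follow the new defaults in `default_args`; the source file name is the hypothetical example from above), the embedding produced by `aesthetic_gradient.py` is picked up by dropping it into `embeddings_ag/`; its file stem then becomes the keyword that activates the aesthetic gradient whenever it appears in a prompt:

```python
# Minimal sketch, assuming the default directories from infer.py's default_args.
from pathlib import Path
import shutil

ag_dir = Path("embeddings_ag")
ag_dir.mkdir(parents=True, exist_ok=True)

# The file stem ("my-style") becomes the trigger keyword: any prompt containing
# "my-style" runs the aesthetic-gradient personalization step before text encoding.
shutil.copy("output/aesthetic-gradient/my-style-2022-10-13T21-11-53.pt", ag_dir / "my-style.pt")
```
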
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8927a78..1a84c8d 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -4,13 +4,14 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
+import torch.optim as optim
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -50,6 +51,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        self.aesthetic_gradient_embeddings = {}
+        self.aesthetic_gradient_lr = 1e-4
+        self.aesthetic_gradient_iters = 10
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -58,6 +63,47 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
+    def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor):
+        self.aesthetic_gradient_embeddings[keyword] = tensor
+
+    def get_text_embeddings(self, prompt, text_input_ids):
+        prompt = " ".join(prompt)
+
+        embeddings = [
+            embedding
+            for key, embedding in self.aesthetic_gradient_embeddings.items()
+            if key in prompt
+        ]
+
+        if len(embeddings) != 0:
+            with torch.enable_grad():
+                full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+                full_clip_model.to(self.device)
+                full_clip_model.text_model.train()
+
+                optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
+
+                for embs in embeddings:
+                    embs = embs.clone().detach().to(self.device)
+                    embs /= embs.norm(dim=-1, keepdim=True)
+
+                    for i in range(self.aesthetic_gradient_iters):
+                        text_embs = full_clip_model.get_text_features(text_input_ids)
+                        text_embs /= text_embs.norm(dim=-1, keepdim=True)
+                        sim = text_embs @ embs.T
+                        loss = -sim
+                        loss = loss.mean()
+
+                        loss.backward()
+                        optimizer.step()
+                        optimizer.zero_grad()
+
+                full_clip_model.text_model.eval()
+
+                return full_clip_model.text_model(text_input_ids)[0]
+        else:
+            return self.text_encoder(text_input_ids)[0]
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -195,7 +241,7 @@
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device))
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -207,7 +253,7 @@
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device))
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
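
The heart of this change is `get_text_embeddings`: when the prompt contains a registered keyword, a fresh CLIP model is fine-tuned for a few steps so that its text features move toward the stored aesthetic image embedding (the loss is the negative cosine similarity), and the adapted text model then encodes the prompt. Below is a condensed, standalone sketch of that optimization, not the pipeline code itself; the function name is made up here, and the defaults mirror the diff's `aesthetic_gradient_lr` and `aesthetic_gradient_iters`:

```python
# Condensed sketch of the aesthetic-gradient step (illustrative, not the pipeline code).
# `text_input_ids` is a tokenized prompt; `aesthetic_emb` is the saved [1, 768] image embedding.
import torch
from torch import optim
from transformers import CLIPModel

def personalize_text_encoder(text_input_ids, aesthetic_emb, lr=1e-4, iters=10, device="cuda"):
    clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    clip.text_model.train()
    optimizer = optim.Adam(clip.text_model.parameters(), lr=lr)

    text_input_ids = text_input_ids.to(device)
    target = aesthetic_emb.to(device)
    target = target / target.norm(dim=-1, keepdim=True)

    for _ in range(iters):
        text_embs = clip.get_text_features(text_input_ids)
        text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
        loss = -(text_embs @ target.T).mean()  # maximize cosine similarity to the aesthetic embedding
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    clip.text_model.eval()
    # The adapted text model now produces the conditioning passed to the UNet.
    return clip.text_model(text_input_ids)[0]
```
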
diff --git a/textual_inversion.py b/textual_inversion.py
index 181a318..9d2840d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -193,12 +193,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=-1,
-        help="For distributed training: local_rank"
-    )
-    parser.add_argument(
         "--checkpoint_frequency",
         type=int,
         default=500,
@@ -280,10 +274,6 @@ def parse_args():
             args = parser.parse_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))
 
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
     if args.train_data_file is None:
         raise ValueError("You must specify --train_data_file")
 
