Diffstat (limited to 'infer.py')

 infer.py (-rw-r--r--) | 75
 1 file changed, 52 insertions(+), 23 deletions(-)
@@ -23,7 +23,8 @@ default_args = {
     "model": None,
     "scheduler": "euler_a",
     "precision": "fp32",
-    "embeddings_dir": "embeddings",
+    "ti_embeddings_dir": "embeddings_ti",
+    "ag_embeddings_dir": "embeddings_ag",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -73,7 +74,11 @@ def create_args_parser():
         choices=["fp32", "fp16", "bf16"],
     )
     parser.add_argument(
-        "--embeddings_dir",
+        "--ti_embeddings_dir",
+        type=str,
+    )
+    parser.add_argument(
+        "--ag_embeddings_dir",
         type=str,
     )
     parser.add_argument(
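Both new options are plain string paths; their fallback values come from the default_args dict updated above. An invocation that sets them explicitly would look something like the following (model path illustrative):

python infer.py --model <path-or-hub-id> --ti_embeddings_dir embeddings_ti --ag_embeddings_dir embeddings_ag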
@@ -167,42 +172,63 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(tokenizer, text_encoder, embeddings_dir):
+def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
+    print(f"Loading Textual Inversion embeddings")
+
     embeddings_dir = Path(embeddings_dir)
     embeddings_dir.mkdir(parents=True, exist_ok=True)
 
     for file in embeddings_dir.iterdir():
-        placeholder_token = file.stem
+        if file.is_file():
+            placeholder_token = file.stem
 
-        num_added_tokens = tokenizer.add_tokens(placeholder_token)
-        if num_added_tokens == 0:
-            raise ValueError(
-                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
-                " `placeholder_token` that is not already in the tokenizer."
-            )
+            num_added_tokens = tokenizer.add_tokens(placeholder_token)
+            if num_added_tokens == 0:
+                raise ValueError(
+                    f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+                    " `placeholder_token` that is not already in the tokenizer."
+                )
 
     text_encoder.resize_token_embeddings(len(tokenizer))
 
     token_embeds = text_encoder.get_input_embeddings().weight.data
 
     for file in embeddings_dir.iterdir():
-        placeholder_token = file.stem
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
+        if file.is_file():
+            placeholder_token = file.stem
+            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
 
-        data = torch.load(file, map_location="cpu")
+            data = torch.load(file, map_location="cpu")
 
-        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
 
-        emb = next(iter(data.values()))
-        if len(emb.shape) == 1:
-            emb = emb.unsqueeze(0)
+            emb = next(iter(data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
 
-        token_embeds[placeholder_token_id] = emb
+            token_embeds[placeholder_token_id] = emb
 
-        print(f"Loaded embedding: {placeholder_token}")
+            print(f"Loaded {placeholder_token}")
+
+
+def load_embeddings_ag(pipeline, embeddings_dir):
+    print(f"Loading Aesthetic Gradient embeddings")
+
+    embeddings_dir = Path(embeddings_dir)
+    embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+    for file in embeddings_dir.iterdir():
+        if file.is_file():
+            placeholder_token = file.stem
+
+            data = torch.load(file, map_location="cpu")
+
+            pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
+
+            print(f"Loaded {placeholder_token}")
 
 
-def create_pipeline(model, scheduler, embeddings_dir, dtype):
+def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
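For context, load_embeddings_ti() fixes a simple on-disk format: one file per token, the file stem becomes the placeholder token, and the payload must be a dict with exactly one entry whose value is the embedding tensor (1-D tensors are unsqueezed into a row). A hypothetical helper that writes a compatible file, assuming the 768-dim CLIP ViT-L/14 text width used by Stable Diffusion v1:

# Hypothetical helper, not part of this commit: writes an embedding in the
# single-entry dict format that load_embeddings_ti() expects.
from pathlib import Path

import torch


def save_ti_embedding(embeddings_dir, placeholder_token, embedding):
    # The loader derives the token from the file stem, so the file must be
    # named after the token; the dict key itself is never read back.
    path = Path(embeddings_dir)
    path.mkdir(parents=True, exist_ok=True)
    torch.save({placeholder_token: embedding}, path / f"{placeholder_token}.bin")


# e.g. a freshly initialized token embedding
save_ti_embedding("embeddings_ti", "<my-style>", torch.randn(768))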
@@ -210,7 +236,7 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
 
-    load_embeddings(tokenizer, text_encoder, embeddings_dir)
+    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
 
     if scheduler == "plms":
         scheduler = PNDMScheduler(
@@ -236,10 +262,13 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
+    pipeline.aesthetic_gradient_iters = 30
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
 
+    load_embeddings_ag(pipeline, ag_embeddings_dir)
+
     return pipeline
 
 
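Neither add_aesthetic_gradient_embedding() nor the consumer of aesthetic_gradient_iters appears in this diff; both belong to the custom pipeline class. Going by the Aesthetic Gradients technique (Gallego, 2022) that the naming points to, each stored CLIP embedding is used to personalize a prompt at sampling time via a short gradient-descent loop, with aesthetic_gradient_iters = 30 as the step count. A rough sketch of that idea, under those assumptions (everything below except the two names above is hypothetical):

# Rough sketch of the aesthetic-gradients idea the pipeline presumably
# implements; NOT code from this repository. The prompt's text embedding
# is nudged toward a stored "aesthetic" CLIP embedding for a fixed number
# of steps before sampling.
import torch
import torch.nn.functional as F


def apply_aesthetic_gradient(text_embeddings, aesthetic_embedding, iters=30, lr=1e-4):
    emb = text_embeddings.detach().clone().requires_grad_(True)
    target = F.normalize(aesthetic_embedding.float(), dim=-1)
    for _ in range(iters):
        # maximize cosine similarity between the pooled prompt embedding
        # and the aesthetic embedding
        pooled = F.normalize(emb.mean(dim=1), dim=-1)
        loss = -(pooled * target).sum(dim=-1).mean()
        loss.backward()
        with torch.no_grad():
            emb -= lr * emb.grad
            emb.grad = None
    return emb.detach()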
@@ -259,7 +288,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
-    with torch.autocast("cuda"), torch.inference_mode():
+    with torch.autocast("cuda"):
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
                 desc=f"Batch {i + 1} of {args.batch_num}",
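Dropping torch.inference_mode() is presumably what makes an optimization loop like the sketch above possible: inference_mode() disables autograd outright, so no gradient step could run inside it, while torch.autocast alone keeps the mixed-precision behavior without blocking gradients.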
@@ -337,7 +366,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
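Finally, a hypothetical direct use of the updated create_pipeline() signature, mirroring what main() now does (the model path is illustrative; the directories are the new defaults):

import torch

from infer import create_pipeline

pipeline = create_pipeline(
    "path/to/stable-diffusion-model",  # local path or hub id
    "euler_a",                         # default scheduler from default_args
    "embeddings_ti",                   # Textual Inversion embeddings
    "embeddings_ag",                   # Aesthetic Gradient embeddings
    torch.float32,                     # "fp32" in the precision mapping
)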
