From 3396ca881ed3f3521617cd9024eea56975191d32 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 5 Jan 2023 13:26:32 +0100
Subject: Update

---
 data/csv.py           | 22 +---------------------
 data/keywords.py      | 21 +++++++++++++++++++++
 infer.py              | 40 +++++++++++++++++++++++++++++-----------
 models/clip/prompt.py | 10 ++++++----
 train_dreambooth.py   | 14 +++++++-------
 train_ti.py           | 14 +++++++-------
 training/common.py    |  5 ++++-
 7 files changed, 75 insertions(+), 51 deletions(-)
 create mode 100644 data/keywords.py

diff --git a/data/csv.py b/data/csv.py
index a60733a..d1f3054 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,7 +1,6 @@
 import math
 import torch
 import json
-import numpy as np
 from pathlib import Path
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
@@ -9,32 +8,13 @@ from torchvision import transforms
 from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
 from models.clip.prompt import PromptProcessor
+from data.keywords import prompt_to_keywords, keywords_to_prompt
 
 
 def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
-    if dropout != 0:
-        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
-    if shuffle:
-        np.random.shuffle(prompt)
-    return ", ".join(prompt)
-
-
-def prompt_to_keywords(prompt: str, expansions: dict[str, str]) -> list[str]:
-    def expand_keyword(keyword: str) -> list[str]:
-        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
-
-    return [
-        kw
-        for keyword in prompt.split(", ")
-        for kw in expand_keyword(keyword)
-        if keyword != ""
-    ]
-
-
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
diff --git a/data/keywords.py b/data/keywords.py
new file mode 100644
index 0000000..9e656f3
--- /dev/null
+++ b/data/keywords.py
@@ -0,0 +1,21 @@
+import numpy as np
+
+
+def keywords_to_prompt(prompt: list[str], dropout: float = 0, shuffle: bool = False) -> str:
+    if dropout != 0:
+        prompt = [keyword for keyword in prompt if np.random.random() > dropout]
+    if shuffle:
+        np.random.shuffle(prompt)
+    return ", ".join(prompt)
+
+
+def prompt_to_keywords(prompt: str, expansions: dict[str, str] = {}) -> list[str]:
+    def expand_keyword(keyword: str) -> list[str]:
+        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
+
+    return [
+        kw
+        for keyword in prompt.split(", ")
+        for kw in expand_keyword(keyword)
+        if keyword != ""
+    ]
diff --git a/infer.py b/infer.py
index 507d0cf..9c27db4 100644
--- a/infer.py
+++ b/infer.py
@@ -25,6 +25,7 @@ from diffusers import (
 )
 from transformers import CLIPTextModel
 
+from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
@@ -49,6 +50,7 @@ default_cmds = {
     "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
+    "shuffle": True,
     "image": None,
     "image_noise": .7,
     "width": 768,
@@ -125,6 +127,10 @@ def create_cmd_parser():
         type=str,
         nargs="*",
     )
+    parser.add_argument(
+        "--shuffle",
+        type=bool,
+    )
     parser.add_argument(
         "--image",
         type=str,
@@ -197,7 +203,7 @@ def load_embeddings(pipeline, embeddings_dir):
         pipeline.text_encoder.text_model.embeddings,
         Path(embeddings_dir)
     )
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
 
 def create_pipeline(model, dtype):
@@ -228,20 +234,35 @@ def create_pipeline(model, dtype):
 
 
 @torch.inference_mode()
-def generate(output_dir, pipeline, args):
+def generate(output_dir: Path, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
 
+    if args.shuffle:
+        args.prompt *= args.batch_size
+        args.batch_size = 1
+        args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt]
+
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    use_subdirs = len(args.prompt) != 1
-    if use_subdirs:
+    image_dir = []
+
+    if len(args.prompt) != 1:
         if len(args.project) != 0:
             output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
         else:
             output_dir = output_dir.joinpath(now)
+
+        for prompt in args.prompt:
+            dir = output_dir.joinpath(slugify(prompt)[:100])
+            dir.mkdir(parents=True, exist_ok=True)
+            image_dir.append(dir)
+
+            with open(dir.joinpath('prompt.txt'), 'w') as f:
+                f.write(prompt)
     else:
         output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
-    output_dir.mkdir(parents=True, exist_ok=True)
+        output_dir.mkdir(parents=True, exist_ok=True)
+        image_dir.append(output_dir)
 
     args.seed = args.seed or torch.random.seed()
 
@@ -293,12 +314,9 @@
         ).images
 
         for j, image in enumerate(images):
-            image_dir = output_dir
-            if use_subdirs:
-                image_dir = image_dir.joinpath(slugify(args.prompt[j % len(args.prompt)])[:100])
-                image_dir.mkdir(parents=True, exist_ok=True)
-            image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
-            image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
+            dir = image_dir[j % len(args.prompt)]
+            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
+            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
diff --git a/models/clip/prompt.py b/models/clip/prompt.py
index 9da3955..a7380be 100644
--- a/models/clip/prompt.py
+++ b/models/clip/prompt.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
 
 import torch
 
@@ -16,7 +16,7 @@ class PromptProcessor():
             padding="do_not_pad",
         ).input_ids
 
-    def unify_input_ids(self, input_ids: list[int]):
+    def unify_input_ids(self, input_ids: list[list[int]]):
         return self.tokenizer.pad(
             {"input_ids": input_ids},
             padding=True,
@@ -24,13 +24,15 @@ class PromptProcessor():
             return_tensors="pt"
         )
 
-    def get_embeddings(self, input_ids: torch.IntTensor, attention_mask=None):
+    def get_embeddings(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, attention_mask=None):
         prompts = input_ids.shape[0]
 
         input_ids = input_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
+        if position_ids is not None:
+            position_ids = position_ids.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
         if attention_mask is not None:
             attention_mask = attention_mask.view((-1, self.tokenizer.model_max_length)).to(self.text_encoder.device)
 
-        text_embeddings = self.text_encoder(input_ids, attention_mask=attention_mask)[0]
+        text_embeddings = self.text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0]
         text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
         return text_embeddings
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1fd86b1..4d1e0a3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -861,13 +861,13 @@ def main():
 
     loop = partial(
         run_model,
-        vae=vae,
-        noise_scheduler=noise_scheduler,
-        unet=unet,
-        prompt_processor=prompt_processor,
-        num_class_images=args.num_class_images,
-        prior_loss_weight=args.prior_loss_weight,
-        seed=args.seed,
+        vae,
+        noise_scheduler,
+        unet,
+        prompt_processor,
+        args.num_class_images,
+        args.prior_loss_weight,
+        args.seed,
     )
 
     # We need to initialize the trackers we use, and also store our configuration.
diff --git a/train_ti.py b/train_ti.py
index 164cf67..98385dd 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -814,13 +814,13 @@ def main():
 
     loop = partial(
         run_model,
-        vae=vae,
-        noise_scheduler=noise_scheduler,
-        unet=unet,
-        prompt_processor=prompt_processor,
-        num_class_images=args.num_class_images,
-        prior_loss_weight=args.prior_loss_weight,
-        seed=args.seed,
+        vae,
+        noise_scheduler,
+        unet,
+        prompt_processor,
+        args.num_class_images,
+        args.prior_loss_weight,
+        args.seed,
     )
 
     # We need to initialize the trackers we use, and also store our configuration.
diff --git a/training/common.py b/training/common.py
index 99a6e67..ab2741a 100644
--- a/training/common.py
+++ b/training/common.py
@@ -40,7 +40,10 @@ def run_model(
     noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
     # Get the text embedding for conditioning
-    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+    encoder_hidden_states = prompt_processor.get_embeddings(
+        batch["input_ids"],
+        batch["attention_mask"]
+    )
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
 
     # Predict the noise residual
-- 
cgit v1.2.3-70-g09d2
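
For reference, a minimal usage sketch of the keyword helpers this patch moves into data/keywords.py, assuming the patch has been applied; the prompt string and the expansion table below are invented for illustration and are not part of the commit.

    from data.keywords import prompt_to_keywords, keywords_to_prompt

    # Split a comma-separated prompt into keywords, expanding any keyword
    # that has an entry in the (optional) expansions mapping.
    keywords = prompt_to_keywords(
        "portrait, example_token",
        expansions={"example_token": "detailed, close-up"},  # hypothetical expansion table
    )
    # -> ['portrait', 'example_token', 'detailed', 'close-up']

    # Re-join the keywords into a prompt; shuffle=True randomizes their order,
    # which is what infer.py now does per prompt when --shuffle is enabled.
    print(keywords_to_prompt(keywords, shuffle=True))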