From 515f0f1fdc9a76bf63bd746c291dcfec7fc747fb Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 13 Oct 2022 21:11:53 +0200
Subject: Added support for Aesthetic Gradients

---
 aesthetic_gradient.py | 137 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 137 insertions(+)
 create mode 100644 aesthetic_gradient.py

diff --git a/aesthetic_gradient.py b/aesthetic_gradient.py
new file mode 100644
index 0000000..5386d0f
--- /dev/null
+++ b/aesthetic_gradient.py
@@ -0,0 +1,137 @@
+import argparse
+import datetime
+import logging
+import json
+from pathlib import Path
+
+import torch
+import torch.utils.checkpoint
+from torchvision import transforms
+import pandas as pd
+
+from accelerate.logging import get_logger
+from PIL import Image
+from tqdm import tqdm
+from transformers import CLIPModel
+from slugify import slugify
+
+logger = get_logger(__name__)
+
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Create an aesthetic gradient embedding by averaging CLIP image features over a set of images."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="Path to a CSV file listing the training images; image paths are resolved relative to this file.",
+    )
+    parser.add_argument(
+        "--token",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=224,
+        help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+            " resolution."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/aesthetic-gradient",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script.",
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
+
+    if args.token is None:
+        raise ValueError("You must specify --token")
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir)
+    basepath.mkdir(parents=True, exist_ok=True)
+    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")
+
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
+    data_file = Path(args.train_data_file)
+    if not data_file.is_file():
+        raise ValueError("data_file must be a file")
+    data_root = data_file.parent
+    metadata = pd.read_csv(data_file)
+    # Rows marked with an "x" in the optional "skip" column are excluded.
+    image_paths = [
+        data_root.joinpath(item.image)
+        for item in metadata.itertuples()
+        if getattr(item, "skip", None) != "x"
+    ]
+
+    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+
+    # Resize/crop to the input resolution and scale pixel values to [-1, 1].
+    image_transforms = transforms.Compose(
+        [
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
+            transforms.RandomCrop(args.resolution),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
+
+    # Embed every image with the CLIP vision tower and average the
+    # per-image embeddings into a single aesthetic embedding.
+    with torch.no_grad():
+        embs = []
+        for path in tqdm(image_paths):
+            image = Image.open(path)
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+            image = image_transforms(image).unsqueeze(0)
+            emb = model.get_image_features(image)
+            embs.append(emb)
+
+        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
+
+    print(embs.shape)
+
+    torch.save(embs, target)
+
+
+if __name__ == "__main__":
+    main()
-- 
cgit v1.2.3-54-g00ecf
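
Usage sketch (hypothetical paths and token; the CSV is assumed to have an "image" column with paths relative to the CSV file, and rows can be excluded via an optional "skip" column set to "x"):

    python aesthetic_gradient.py --train_data_file data/metadata.csv --token my-style

The script writes a single averaged CLIP image embedding to the output directory; with "openai/clip-vit-large-patch14" this should be a tensor of shape (1, 768), which downstream code could reload like this (the file name below is a placeholder; the real name is "<slugified-token>-<timestamp>.pt"):

    import torch

    emb = torch.load("output/aesthetic-gradient/my-style-2022-10-13T21-11-53.pt")
    print(emb.shape)  # expected: torch.Size([1, 768])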