import argparse
import datetime
import json
import logging
from pathlib import Path

import pandas as pd
import torch
from accelerate.logging import get_logger
from PIL import Image
from slugify import slugify
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPModel

logger = get_logger(__name__)

torch.backends.cuda.matmul.allow_tf32 = True


def parse_args():
    parser = argparse.ArgumentParser(
        description="Compute an averaged CLIP image embedding (aesthetic gradient) for a concept from a set of images."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="openai/clip-vit-large-patch14",
        help="Path to pretrained CLIP model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="Path to a CSV metadata file; image paths in its `image` column are resolved relative to the file's directory.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=224,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/aesthetic-gradient",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script.",
    )
    args = parser.parse_args()

    if args.config is not None:
        # Re-parse with the config file's values as namespace defaults, so
        # flags given on the command line still override the config file.
        with open(args.config, "rt") as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")
    if args.token is None:
        raise ValueError("You must specify --token")
    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def main():
    args = parse_args()

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    basepath = Path(args.output_dir)
    basepath.mkdir(parents=True, exist_ok=True)
    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")
    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

    data_file = Path(args.train_data_file)
    if not data_file.is_file():
        raise ValueError("--train_data_file must point to a CSV file")
    data_root = data_file.parent
    metadata = pd.read_csv(data_file)

    # Keep every row unless its optional `skip` column is marked with "x".
    image_paths = [
        data_root.joinpath(item.image)
        for item in metadata.itertuples()
        if getattr(item, "skip", None) != "x"
    ]

    model = CLIPModel.from_pretrained(args.pretrained_model_name_or_path)

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    with torch.no_grad():
        embs = []
        for path in tqdm(image_paths):
            image = Image.open(path)
            if image.mode != "RGB":
                image = image.convert("RGB")
            pixel_values = image_transforms(image).unsqueeze(0)
            emb = model.get_image_features(pixel_values=pixel_values)
            logger.debug("image embedding shape: %s", emb.shape)
            embs.append(emb)

        # Average the per-image CLIP embeddings into a single concept embedding.
        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

    logger.info("saving embedding of shape %s to %s", embs.shape, target)
    torch.save(embs, target)


if __name__ == "__main__":
    main()
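
# --- Usage sketch (illustrative; file names and paths are hypothetical) -----
# The CSV passed via --train_data_file is expected to have an `image` column
# of paths relative to the CSV's own directory, plus an optional `skip`
# column where an "x" excludes that row:
#
#   image,skip
#   photos/001.jpg,
#   photos/002.jpg,x
#
# Example invocation:
#
#   python aesthetic_gradient.py \
#       --train_data_file data/metadata.csv \
#       --token "my concept" \
#       --output_dir output/aesthetic-gradient
#
# The result is a single averaged CLIP image embedding saved as
# `<slugified-token>-<timestamp>.pt` inside --output_dir.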