path: root/aesthetic_gradient.py
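"""Compute an "aesthetic gradient" embedding: the mean CLIP image embedding of a
set of images listed in a CSV file, saved as a single .pt tensor.

The CSV is expected to have an "image" column with paths relative to the CSV's
own directory, plus an optional "skip" column; rows marked "x" there are ignored.

Example invocation (file names are illustrative):

    python aesthetic_gradient.py --train_data_file data/images.csv --token my-style
"""
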
import argparse
import datetime
import logging
import json
from pathlib import Path

import torch
from torchvision import transforms
import pandas as pd

from accelerate.logging import get_logger
from PIL import Image
from tqdm import tqdm
from transformers import CLIPModel
from slugify import slugify

logger = get_logger(__name__)


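# Enable TF32 matmuls, which speeds up inference on Ampere-class GPUs at a small precision cost.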
torch.backends.cuda.matmul.allow_tf32 = True


def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A directory."
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=224,
        help=(
            "The resolution for input images; every image in the training set is resized and cropped to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/aesthetic-gradient",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
    )

    args = parser.parse_args()
    if args.config is not None:
        with open(args.config, 'rt') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.token is None:
        raise ValueError("You must specify --token")

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def main():
    args = parse_args()

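    # Each run writes a timestamped .pt file named after the slugified token.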
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    basepath = Path(args.output_dir)
    basepath.mkdir(parents=True, exist_ok=True)
    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")

    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

    data_file = Path(args.train_data_file)
    if not data_file.is_file():
        raise ValueError("data_file must be a file")
    data_root = data_file.parent
    metadata = pd.read_csv(data_file)
    # Collect image paths from the CSV, ignoring rows marked "x" in an optional "skip" column.
    # Membership checks on itertuples() rows look at the values, so the column list is checked instead.
    image_paths = [
        data_root.joinpath(item.image)
        for item in metadata.itertuples()
        if "skip" not in metadata.columns or item.skip != "x"
    ]

    # Use the model given on the command line, falling back to the default CLIP checkpoint.
    model = CLIPModel.from_pretrained(
        args.pretrained_model_name_or_path or "openai/clip-vit-large-patch14"
    )

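    # Resize the short side to the target resolution, random-crop to a square, and map pixels
    # to [-1, 1]. Note this is a simple normalization, not the CLIP processor's own statistics.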
    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

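    # Inference only: run CLIP's image encoder without tracking gradients.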
    with torch.no_grad():
        embs = []
        for path in tqdm(image_paths):
            image = Image.open(path)
            if not image.mode == "RGB":
                image = image.convert("RGB")
            image = image_transforms(image).unsqueeze(0)
            emb = model.get_image_features(image)
            print(f">>>> {emb.shape}")
            embs.append(emb)

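        # Average all per-image embeddings into a single (1, embedding_dim) tensor.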
        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

        print(f"Saving aesthetic embedding with shape {tuple(embs.shape)} to {target}")

        torch.save(embs, target)


if __name__ == "__main__":
    main()