Diffstat (limited to 'aesthetic_gradient.py')
-rw-r--r--  aesthetic_gradient.py  137
1 file changed, 137 insertions(+), 0 deletions(-)
diff --git a/aesthetic_gradient.py b/aesthetic_gradient.py
new file mode 100644
index 0000000..5386d0f
--- /dev/null
+++ b/aesthetic_gradient.py
@@ -0,0 +1,137 @@
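"""Compute an aesthetic embedding: the mean CLIP image embedding of a set of
images, saved as a .pt tensor named after the --token.

Example invocation (hypothetical paths):

    python aesthetic_gradient.py --train_data_file data/train.csv --token my-style
"""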
import argparse
import datetime
import logging
import json
from pathlib import Path

import torch
from torchvision import transforms
import pandas as pd

from PIL import Image
from tqdm import tqdm
from transformers import CLIPModel
from slugify import slugify

logger = logging.getLogger(__name__)

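# Allow TF32 in matmuls for faster inference on Ampere and newer GPUs.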
torch.backends.cuda.matmul.allow_tf32 = True


def parse_args():
    parser = argparse.ArgumentParser(
        description="Compute an aesthetic embedding from a set of images."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="Path to a CSV file listing the training images.",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=224,
        help=(
            "The resolution for input images; all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/aesthetic-gradient",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script.",
    )

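    # Parse once from the command line; if --config is given, re-parse with the
    # JSON file's "args" object as the namespace, so explicit CLI flags still win.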
    args = parser.parse_args()
    if args.config is not None:
        with open(args.config, 'rt') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.token is None:
        raise ValueError("You must specify --token")

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def main():
    args = parse_args()

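    # Name the output after the slugified token plus a timestamp so repeated
    # runs never overwrite an earlier embedding.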
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    basepath = Path(args.output_dir)
    basepath.mkdir(parents=True, exist_ok=True)
    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")

    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

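    # The CSV must contain an "image" column with paths relative to the CSV's
    # directory; an optional "skip" column set to "x" excludes a row.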
    data_file = Path(args.train_data_file)
    if not data_file.is_file():
        raise ValueError("data_file must be a file")
    data_root = data_file.parent
    metadata = pd.read_csv(data_file)
    image_paths = [
        data_root.joinpath(item.image)
        for item in metadata.itertuples()
        if getattr(item, "skip", None) != "x"
    ]

    model = CLIPModel.from_pretrained(args.pretrained_model_name_or_path or "openai/clip-vit-large-patch14")

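    # Resize the short side to --resolution, random-crop to a square, and map
    # pixels to [-1, 1]. Note: the stock CLIP preprocessor center-crops and
    # normalizes with CLIP's own per-channel statistics; 0.5/0.5 is a simpler
    # approximation.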
    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

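    # Embed every image with the CLIP vision tower, then average the
    # embeddings into a single vector and save it.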
    with torch.no_grad():
        embs = []
        for path in tqdm(image_paths):
            image = Image.open(path)
            if image.mode != "RGB":
                image = image.convert("RGB")
            image = image_transforms(image).unsqueeze(0)
            emb = model.get_image_features(image)
            logger.debug("embedded %s: %s", path, tuple(emb.shape))
            embs.append(emb)

        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

        print(f"Saving mean embedding of shape {tuple(embs.shape)} to {target}")

        torch.save(embs, target)


if __name__ == "__main__":
    main()