import json
from pathlib import Path

import torch
from PIL import Image
def save_args(basepath: Path, args, extra=None):
    """Serialize the parsed CLI arguments (plus optional extras) to args.json."""
    info = {"args": vars(args)}
    info["args"].update(extra or {})
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)
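
# A minimal usage sketch, assuming `args` came from argparse's parse_args();
# the argument names and the `extra` payload below are illustrative only:
#
#   save_args(Path("output/run-1"), args, extra={"git_commit": "abc123"})
#
# writes output/run-1/args.json with all argparse fields plus the extras
# merged into a single "args" object.
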
def make_grid(images, rows, cols):
    # All images are assumed to share the size of the first one.
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        # Fill the grid left-to-right, top-to-bottom:
        # column = i % cols, row = i // cols.
        grid.paste(image, box=((i % cols) * w, (i // cols) * h))
    return grid
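
# Hedged usage sketch: with rows * cols equally sized images, the result is
# a single contact-sheet image. For example, 8 placeholder images in a
# 2 x 4 grid:
#
#   thumbs = [Image.new("RGB", (64, 64)) for _ in range(8)]
#   make_grid(thumbs, rows=2, cols=4).size  # -> (256, 128)
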
class AverageMeter:
    """Tracks the running sum, count, and average of a scalar metric."""

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        # `n` is the number of samples that `val` was averaged over.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
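
# Hedged usage sketch; `train_step` and the batch-size lookup are
# illustrative, not part of this module:
#
#   loss_meter = AverageMeter("loss")
#   for batch in dataloader:
#       loss = train_step(batch)  # hypothetical step returning a float loss
#       loss_meter.update(loss, n=len(batch["prompts"]))
#   print(f"{loss_meter.name}: {loss_meter.avg:.4f}")
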
class CheckpointerBase:
    def __init__(
        self,
        datamodule,
        output_dir: Path,
        placeholder_token,
        placeholder_token_id,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed,
    ):
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.sample_image_size = sample_image_size
        # Fall back to a fresh random seed if none was given.
        self.seed = seed or torch.random.seed()
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size
    @torch.inference_mode()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        samples_path = Path(self.output_dir).joinpath("samples")

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()

        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)

        grid_cols = min(self.sample_batch_size, 4)
        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols

        # The "stable" pool reuses the fixed seed so its images stay comparable
        # across steps; "val" and "train" sample with fresh randomness.
        for pool, data, gen in [
            ("stable", val_data, generator),
            ("val", val_data, None),
            ("train", train_data, None),
        ]:
            all_samples = []
            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
            file_path.parent.mkdir(parents=True, exist_ok=True)

            # Take just enough batches to cover
            # sample_batches * sample_batch_size prompts.
            batches = [
                batch
                for j, batch in enumerate(data)
                if j * data.batch_size < self.sample_batch_size * self.sample_batches
            ]
            prompts = [
                prompt
                for batch in batches
                for prompt in batch["prompts"]
            ]
            nprompts = [
                prompt
                for batch in batches
                for prompt in batch["nprompts"]
            ]

            for i in range(self.sample_batches):
                prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
                nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]

                samples = pipeline(
                    prompt=prompt,
                    negative_prompt=nprompt,
                    height=self.sample_image_size,
                    width=self.sample_image_size,
                    generator=gen,
                    guidance_scale=guidance_scale,
                    eta=eta,
                    num_inference_steps=num_inference_steps,
                    output_type="pil",
                ).images

                all_samples += samples
                del samples

            image_grid = make_grid(all_samples, grid_rows, grid_cols)
            image_grid.save(file_path, quality=85)

            del all_samples
            del image_grid

        del generator
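
# Hedged usage sketch, assuming a datamodule whose dataloaders yield batches
# with "prompts"/"nprompts" keys and a diffusers pipeline already moved to
# the target device; every concrete name and value below is illustrative:
#
#   checkpointer = CheckpointerBase(
#       datamodule=datamodule,
#       output_dir=Path("output/run-1"),
#       placeholder_token="<token>",
#       placeholder_token_id=placeholder_token_id,
#       sample_image_size=512,
#       sample_batches=4,
#       sample_batch_size=4,
#       seed=42,
#   )
#   checkpointer.save_samples(pipeline, step=500, num_inference_steps=20)
#
# This writes samples/stable/step_500.jpg, samples/val/step_500.jpg, and
# samples/train/step_500.jpg under output/run-1, each a 4 x 4 contact sheet.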