path: root/training/util.py
author     Volpeon <git@volpeon.ink>  2022-12-21 09:17:25 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-21 09:17:25 +0100
commit     68540b27849564994d921968a36faa9b997e626d (patch)
tree       8fbe834ab4c52f057cd114bbb0e786158f215acc /training/util.py
parent     Fix training (diff)
Moved common training code into separate module
Diffstat (limited to 'training/util.py')
-rw-r--r--  training/util.py  131
1 file changed, 131 insertions(+), 0 deletions(-)
diff --git a/training/util.py b/training/util.py
new file mode 100644
index 0000000..e8d22ae
--- /dev/null
+++ b/training/util.py
@@ -0,0 +1,131 @@
from pathlib import Path
import json

import torch
from PIL import Image


def freeze_params(params):
    # Disable gradients so the optimizer leaves these parameters untouched
    # (used e.g. to freeze everything except the new token embedding).
    for param in params:
        param.requires_grad = False


def save_args(basepath: Path, args, extra=None):
    # Dump the argparse namespace (plus optional extra values) to
    # <basepath>/args.json so the run configuration can be reproduced later.
    # `extra=None` instead of a mutable `{}` default avoids the shared
    # default-argument pitfall.
    info = {"args": vars(args)}
    info["args"].update(extra or {})
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


def make_grid(images, rows, cols):
    # Paste `rows * cols` equally sized PIL images into one grid image,
    # filling the grid row by row.
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=((i % cols) * w, (i // cols) * h))
    return grid
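# Illustration only (not part of this commit): eight samples arranged in
# two rows of four, e.g. for a sample sheet written during training:
#
#   sheet = make_grid(images, rows=2, cols=4)
#   sheet.save("samples.jpg", quality=85)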


class AverageMeter:
    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        # `n` is the number of samples `val` was averaged over, so the
        # running average stays correct for uneven batch sizes.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
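# Illustration only (not part of this commit): an AverageMeter keeps a
# running average across batches, e.g. for the training loss; `dataloader`
# and `training_step` are hypothetical stand-ins:
#
#   loss_meter = AverageMeter("loss")
#   for batch in dataloader:
#       loss = training_step(batch)
#       loss_meter.update(loss.item(), n=len(batch["prompts"]))
#   print(loss_meter.name, loss_meter.avg)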


class CheckpointerBase:
    def __init__(
        self,
        datamodule,
        output_dir: Path,
        instance_identifier,
        placeholder_token,
        placeholder_token_id,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed
    ):
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.instance_identifier = instance_identifier
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.sample_image_size = sample_image_size
        # Fall back to a random seed so sampling still works when none is given.
        self.seed = seed or torch.random.seed()
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.no_grad()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        samples_path = Path(self.output_dir).joinpath("samples")

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()

        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
        # Fixed latents for the "stable" pool: reusing the same starting noise
        # at every checkpoint makes the saved grids comparable across steps.
        stable_latents = torch.randn(
            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
            device=pipeline.device,
            generator=generator,
        )

        with torch.autocast("cuda"), torch.inference_mode():
            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                all_samples = []
                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                file_path.parent.mkdir(parents=True, exist_ok=True)

                data_enum = enumerate(data)

                # Take just enough batches to cover sample_batches * sample_batch_size prompts.
                batches = [
                    batch
                    for j, batch in data_enum
                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
                ]
                prompts = [
                    prompt.format(identifier=self.instance_identifier)
                    for batch in batches
                    for prompt in batch["prompts"]
                ]
                nprompts = [
                    prompt
                    for batch in batches
                    for prompt in batch["nprompts"]
                ]

                for i in range(self.sample_batches):
                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]

                    samples = pipeline(
                        prompt=prompt,
                        negative_prompt=nprompt,
                        height=self.sample_image_size,
                        width=self.sample_image_size,
                        image=latents[:len(prompt)] if latents is not None else None,
                        generator=generator if latents is not None else None,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil'
                    ).images

                    all_samples += samples

                    del samples

                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
                image_grid.save(file_path, quality=85)

                del all_samples
                del image_grid

        del generator
        del stable_latents
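A minimal sketch of how a concrete checkpointer might build on this base class; `Checkpointer`, `datamodule`, `pipeline`, and all argument values are hypothetical stand-ins, not part of this commit:

class Checkpointer(CheckpointerBase):
    def __init__(self, datamodule, output_dir):
        super().__init__(
            datamodule=datamodule,
            output_dir=Path(output_dir),
            instance_identifier="sks",      # assumed identifier
            placeholder_token="<concept>",  # assumed token string
            placeholder_token_id=49408,     # assumed token id
            sample_image_size=512,
            sample_batches=2,
            sample_batch_size=4,
            seed=42,
        )

# During training, e.g. every N steps:
#   checkpointer = Checkpointer(datamodule, "output/run-1")
#   checkpointer.save_samples(pipeline, step=global_step, num_inference_steps=20)

Each call would then write grids to output/run-1/samples/{stable,val,train}/step_<step>.jpg.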