diff options
author | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
commit | 68540b27849564994d921968a36faa9b997e626d (patch) | |
tree | 8fbe834ab4c52f057cd114bbb0e786158f215acc /training | |
parent | Fix training (diff) | |
download | textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.gz textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.bz2 textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.zip |
Moved common training code into separate module
Diffstat (limited to 'training')
-rw-r--r-- | training/optimization.py | 2 | ||||
-rw-r--r-- | training/util.py | 131 |
2 files changed, 132 insertions, 1 deletion
diff --git a/training/optimization.py b/training/optimization.py index 0e603fa..c501ed9 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -6,7 +6,7 @@ from diffusers.utils import logging | |||
6 | logger = logging.get_logger(__name__) | 6 | logger = logging.get_logger(__name__) |
7 | 7 | ||
8 | 8 | ||
9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1): | 9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.001, mid_point=0.4, last_epoch=-1): |
10 | """ | 10 | """ |
11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after | 11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after |
12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. | 12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. |
diff --git a/training/util.py b/training/util.py new file mode 100644 index 0000000..e8d22ae --- /dev/null +++ b/training/util.py | |||
@@ -0,0 +1,131 @@ | |||
1 | from pathlib import Path | ||
2 | import json | ||
3 | |||
4 | import torch | ||
5 | from PIL import Image | ||
6 | |||
7 | |||
def freeze_params(params):
    """Disable gradient tracking on every parameter in *params*.

    Used to exclude whole sub-models (e.g. a frozen VAE or text encoder)
    from optimization during training.
    """
    for parameter in params:
        parameter.requires_grad = False
11 | |||
12 | |||
def save_args(basepath: Path, args, extra=None):
    """Serialize the attributes of *args* (e.g. an ``argparse.Namespace``),
    merged with the optional *extra* mapping, to ``<basepath>/args.json``.

    A merged copy is written so that neither *args* nor *extra* is mutated.
    The original implementation called ``vars(args).update(extra)``, which
    updates the live ``__dict__`` of *args* and silently modified the
    caller's namespace; it also used a mutable default argument.

    :param basepath: directory that will contain ``args.json``
    :param args: any object whose ``vars()`` is JSON-serializable
    :param extra: additional key/value pairs to record (overrides *args* keys)
    """
    info = {"args": {**vars(args), **(extra or {})}}
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)
18 | |||
19 | |||
def make_grid(images, rows, cols):
    """Compose *images* into one ``rows`` x ``cols`` grid image.

    Images are placed in row-major order and are assumed to share the size
    of ``images[0]``. Returns a new RGB ``PIL.Image``.
    """
    tile_w, tile_h = images[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for idx, tile in enumerate(images):
        col, row = idx % cols, idx // cols
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas
26 | |||
27 | |||
class AverageMeter:
    """Track a running average of a scalar metric (e.g. training loss).

    Maintains ``sum``, ``count`` and the derived ``avg``; an optional
    ``name`` labels the meter for logging.
    """

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Fold in *n* samples with mean value *val* and refresh ``avg``."""
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
40 | |||
41 | |||
class CheckpointerBase:
    """Shared base for training checkpointers.

    Holds the sampling configuration (image size, batch layout, seed) and
    renders periodic sample-image grids from a diffusion pipeline via
    :meth:`save_samples`. Subclasses are expected to add model-saving logic.
    """

    def __init__(
        self,
        datamodule,
        output_dir: Path,
        instance_identifier,
        placeholder_token,
        placeholder_token_id,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed
    ):
        # datamodule must expose train_dataloader() / val_dataloader();
        # batches must carry "prompts" and "nprompts" lists (see save_samples).
        self.datamodule = datamodule
        self.output_dir = output_dir
        # Substituted into each prompt via prompt.format(identifier=...).
        self.instance_identifier = instance_identifier
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.sample_image_size = sample_image_size
        # NOTE: `or` means a falsy seed (0 or None) is replaced by a fresh
        # random seed — a deliberate seed of 0 cannot be used as-is.
        self.seed = seed or torch.random.seed()
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.no_grad()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Render sample grids for three pools and save them as JPEGs.

        Pools: "stable" (val prompts with fixed pre-drawn latents, so the
        images are comparable across steps), "val" and "train" (fresh noise
        each call). Each grid is written to
        ``<output_dir>/samples/<pool>/step_<step>.jpg``.

        :param pipeline: a diffusers text-to-image pipeline (must accept
            prompt/negative_prompt/image/generator kwargs as used below)
        :param step: global training step, used only in the file name
        :param num_inference_steps: denoising steps per sample
        :param guidance_scale: classifier-free guidance weight
        :param eta: DDIM eta passed through to the pipeline
        """
        samples_path = Path(self.output_dir).joinpath("samples")

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()

        # Seeded generator so the "stable" latents are identical every call.
        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
        stable_latents = torch.randn(
            # (batch, unet channels, H/8, W/8) — the VAE's 8x spatial downscale.
            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
            device=pipeline.device,
            generator=generator,
        )

        with torch.autocast("cuda"), torch.inference_mode():
            # "stable" and "val" intentionally share val_data; only "stable"
            # reuses the fixed latents drawn above.
            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                all_samples = []
                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                file_path.parent.mkdir(parents=True, exist_ok=True)

                data_enum = enumerate(data)

                # Take just enough batches from the dataloader to cover
                # sample_batches * sample_batch_size prompts.
                batches = [
                    batch
                    for j, batch in data_enum
                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
                ]
                prompts = [
                    prompt.format(identifier=self.instance_identifier)
                    for batch in batches
                    for prompt in batch["prompts"]
                ]
                nprompts = [
                    prompt
                    for batch in batches
                    for prompt in batch["nprompts"]
                ]

                for i in range(self.sample_batches):
                    # Re-slice the flat prompt lists into pipeline-sized batches.
                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]

                    samples = pipeline(
                        prompt=prompt,
                        negative_prompt=nprompt,
                        height=self.sample_image_size,
                        width=self.sample_image_size,
                        # Fixed latents (and the seeded generator) only for the
                        # "stable" pool; other pools sample fresh noise.
                        image=latents[:len(prompt)] if latents is not None else None,
                        generator=generator if latents is not None else None,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil'
                    ).images

                    all_samples += samples

                    del samples

                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
                image_grid.save(file_path, quality=85)

                # Explicitly drop large intermediates to keep peak memory down
                # between pools.
                del all_samples
                del image_grid

        del generator
        del stable_latents