diff options
author | Volpeon <git@volpeon.ink> | 2023-01-14 21:53:07 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-14 21:53:07 +0100 |
commit | 83808fe00ac891ad2f625388d144c318b2cb5bfe (patch) | |
tree | b7ca19d27f90be6f02b14f4a39c62fc7250041a2 /training/util.py | |
parent | TI: Prepare UNet with Accelerate as well (diff) | |
download | textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.gz textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.bz2 textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.zip |
WIP: Modularization ("free(): invalid pointer" my ass)
Diffstat (limited to 'training/util.py')
-rw-r--r-- | training/util.py | 214 |
1 files changed, 118 insertions, 96 deletions
diff --git a/training/util.py b/training/util.py index 781cf04..a292edd 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -1,12 +1,40 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | import itertools | 4 | from typing import Iterable, Union |
5 | from typing import Iterable, Optional | ||
6 | from contextlib import contextmanager | 5 | from contextlib import contextmanager |
7 | 6 | ||
8 | import torch | 7 | import torch |
9 | from PIL import Image | 8 | |
9 | from transformers import CLIPTextModel | ||
10 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler | ||
11 | |||
12 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
13 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
14 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | ||
15 | |||
16 | |||
17 | class TrainingStrategy(): | ||
18 | @property | ||
19 | def main_model(self) -> torch.nn.Module: | ||
20 | ... | ||
21 | |||
22 | @contextmanager | ||
23 | def on_train(self, epoch: int): | ||
24 | yield | ||
25 | |||
26 | @contextmanager | ||
27 | def on_eval(self): | ||
28 | yield | ||
29 | |||
30 | def on_before_optimize(self, epoch: int): | ||
31 | ... | ||
32 | |||
33 | def on_after_optimize(self, lr: float): | ||
34 | ... | ||
35 | |||
36 | def on_log(): | ||
37 | return {} | ||
10 | 38 | ||
11 | 39 | ||
12 | def save_args(basepath: Path, args, extra={}): | 40 | def save_args(basepath: Path, args, extra={}): |
@@ -16,12 +44,93 @@ def save_args(basepath: Path, args, extra={}): | |||
16 | json.dump(info, f, indent=4) | 44 | json.dump(info, f, indent=4) |
17 | 45 | ||
18 | 46 | ||
19 | def make_grid(images, rows, cols): | 47 | def generate_class_images( |
20 | w, h = images[0].size | 48 | accelerator, |
21 | grid = Image.new('RGB', size=(cols*w, rows*h)) | 49 | text_encoder, |
22 | for i, image in enumerate(images): | 50 | vae, |
23 | grid.paste(image, box=(i % cols*w, i//cols*h)) | 51 | unet, |
24 | return grid | 52 | tokenizer, |
53 | scheduler, | ||
54 | data_train, | ||
55 | sample_batch_size, | ||
56 | sample_image_size, | ||
57 | sample_steps | ||
58 | ): | ||
59 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | ||
60 | |||
61 | if len(missing_data) == 0: | ||
62 | return | ||
63 | |||
64 | batched_data = [ | ||
65 | missing_data[i:i+sample_batch_size] | ||
66 | for i in range(0, len(missing_data), sample_batch_size) | ||
67 | ] | ||
68 | |||
69 | pipeline = VlpnStableDiffusion( | ||
70 | text_encoder=text_encoder, | ||
71 | vae=vae, | ||
72 | unet=unet, | ||
73 | tokenizer=tokenizer, | ||
74 | scheduler=scheduler, | ||
75 | ).to(accelerator.device) | ||
76 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
77 | |||
78 | with torch.inference_mode(): | ||
79 | for batch in batched_data: | ||
80 | image_name = [item.class_image_path for item in batch] | ||
81 | prompt = [item.cprompt for item in batch] | ||
82 | nprompt = [item.nprompt for item in batch] | ||
83 | |||
84 | images = pipeline( | ||
85 | prompt=prompt, | ||
86 | negative_prompt=nprompt, | ||
87 | height=sample_image_size, | ||
88 | width=sample_image_size, | ||
89 | num_inference_steps=sample_steps | ||
90 | ).images | ||
91 | |||
92 | for i, image in enumerate(images): | ||
93 | image.save(image_name[i]) | ||
94 | |||
95 | del pipeline | ||
96 | |||
97 | if torch.cuda.is_available(): | ||
98 | torch.cuda.empty_cache() | ||
99 | |||
100 | |||
101 | def get_models(pretrained_model_name_or_path: str): | ||
102 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | ||
103 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | ||
104 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | ||
105 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | ||
106 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | ||
107 | sample_scheduler = DPMSolverMultistepScheduler.from_pretrained( | ||
108 | pretrained_model_name_or_path, subfolder='scheduler') | ||
109 | |||
110 | embeddings = patch_managed_embeddings(text_encoder) | ||
111 | |||
112 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | ||
113 | |||
114 | |||
115 | def add_placeholder_tokens( | ||
116 | tokenizer: MultiCLIPTokenizer, | ||
117 | embeddings: ManagedCLIPTextEmbeddings, | ||
118 | placeholder_tokens: list[str], | ||
119 | initializer_tokens: list[str], | ||
120 | num_vectors: Union[list[int], int] | ||
121 | ): | ||
122 | initializer_token_ids = [ | ||
123 | tokenizer.encode(token, add_special_tokens=False) | ||
124 | for token in initializer_tokens | ||
125 | ] | ||
126 | placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors) | ||
127 | |||
128 | embeddings.resize(len(tokenizer)) | ||
129 | |||
130 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | ||
131 | embeddings.add_embed(placeholder_token_id, initializer_token_id) | ||
132 | |||
133 | return placeholder_token_ids, initializer_token_ids | ||
25 | 134 | ||
26 | 135 | ||
27 | class AverageMeter: | 136 | class AverageMeter: |
@@ -38,93 +147,6 @@ class AverageMeter: | |||
38 | self.avg = self.sum / self.count | 147 | self.avg = self.sum / self.count |
39 | 148 | ||
40 | 149 | ||
41 | class CheckpointerBase: | ||
42 | def __init__( | ||
43 | self, | ||
44 | train_dataloader, | ||
45 | val_dataloader, | ||
46 | output_dir: Path, | ||
47 | sample_steps: int = 20, | ||
48 | sample_guidance_scale: float = 7.5, | ||
49 | sample_image_size: int = 768, | ||
50 | sample_batches: int = 1, | ||
51 | sample_batch_size: int = 1, | ||
52 | seed: Optional[int] = None | ||
53 | ): | ||
54 | self.train_dataloader = train_dataloader | ||
55 | self.val_dataloader = val_dataloader | ||
56 | self.output_dir = output_dir | ||
57 | self.sample_image_size = sample_image_size | ||
58 | self.sample_steps = sample_steps | ||
59 | self.sample_guidance_scale = sample_guidance_scale | ||
60 | self.sample_batches = sample_batches | ||
61 | self.sample_batch_size = sample_batch_size | ||
62 | self.seed = seed if seed is not None else torch.random.seed() | ||
63 | |||
64 | @torch.no_grad() | ||
65 | def checkpoint(self, step: int, postfix: str): | ||
66 | pass | ||
67 | |||
68 | @torch.inference_mode() | ||
69 | def save_samples(self, pipeline, step: int): | ||
70 | samples_path = Path(self.output_dir).joinpath("samples") | ||
71 | |||
72 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | ||
73 | |||
74 | grid_cols = min(self.sample_batch_size, 4) | ||
75 | grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols | ||
76 | |||
77 | for pool, data, gen in [ | ||
78 | ("stable", self.val_dataloader, generator), | ||
79 | ("val", self.val_dataloader, None), | ||
80 | ("train", self.train_dataloader, None) | ||
81 | ]: | ||
82 | all_samples = [] | ||
83 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | ||
84 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
85 | |||
86 | batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches)) | ||
87 | prompt_ids = [ | ||
88 | prompt | ||
89 | for batch in batches | ||
90 | for prompt in batch["prompt_ids"] | ||
91 | ] | ||
92 | nprompt_ids = [ | ||
93 | prompt | ||
94 | for batch in batches | ||
95 | for prompt in batch["nprompt_ids"] | ||
96 | ] | ||
97 | |||
98 | for i in range(self.sample_batches): | ||
99 | start = i * self.sample_batch_size | ||
100 | end = (i + 1) * self.sample_batch_size | ||
101 | prompt = prompt_ids[start:end] | ||
102 | nprompt = nprompt_ids[start:end] | ||
103 | |||
104 | samples = pipeline( | ||
105 | prompt=prompt, | ||
106 | negative_prompt=nprompt, | ||
107 | height=self.sample_image_size, | ||
108 | width=self.sample_image_size, | ||
109 | generator=gen, | ||
110 | guidance_scale=self.sample_guidance_scale, | ||
111 | num_inference_steps=self.sample_steps, | ||
112 | output_type='pil' | ||
113 | ).images | ||
114 | |||
115 | all_samples += samples | ||
116 | |||
117 | del samples | ||
118 | |||
119 | image_grid = make_grid(all_samples, grid_rows, grid_cols) | ||
120 | image_grid.save(file_path, quality=85) | ||
121 | |||
122 | del all_samples | ||
123 | del image_grid | ||
124 | |||
125 | del generator | ||
126 | |||
127 | |||
128 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 | 150 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 |
129 | class EMAModel: | 151 | class EMAModel: |
130 | """ | 152 | """ |