| author | Volpeon <git@volpeon.ink> | 2023-01-14 21:53:07 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 21:53:07 +0100 |
| commit | 83808fe00ac891ad2f625388d144c318b2cb5bfe (patch) | |
| tree | b7ca19d27f90be6f02b14f4a39c62fc7250041a2 /training | |
| parent | TI: Prepare UNet with Accelerate as well (diff) | |
| download | textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.gz, textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.tar.bz2, textual-inversion-diff-83808fe00ac891ad2f625388d144c318b2cb5bfe.zip | |
WIP: Modularization ("free(): invalid pointer" my ass)
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py (renamed from training/common.py) | 29 |
|---|---|---|
| -rw-r--r-- | training/lora.py | 107 |
| -rw-r--r-- | training/util.py | 214 |

3 files changed, 130 insertions, 220 deletions
```diff
diff --git a/training/common.py b/training/functional.py
index 5d1e3f9..2d81eca 100644
--- a/training/common.py
+++ b/training/functional.py
@@ -16,19 +16,14 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import AverageMeter, CheckpointerBase
+from training.util import AverageMeter
+from trainer.base import Checkpointer
 
 
-def noop(*args, **kwards):
-    pass
-
-
-def noop_ctx(*args, **kwards):
-    return nullcontext()
-
-
-def noop_on_log():
-    return {}
+def const(result=None):
+    def fn(*args, **kwargs):
+        return result
+    return fn
 
 
 def generate_class_images(
@@ -210,7 +205,7 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: CheckpointerBase,
+    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -218,11 +213,11 @@ def train_loop(
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
-    on_log: Callable[[], dict[str, Any]] = noop_on_log,
-    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-    on_before_optimize: Callable[[int], None] = noop,
-    on_after_optimize: Callable[[float], None] = noop,
-    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
```
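The hunks above fold the old `noop`, `noop_ctx`, and `noop_on_log` defaults into a single `const()` factory. A minimal sketch of that pattern, using only the code shown in the diff plus the standard library:

```python
# Minimal sketch of the const() factory introduced in training/functional.py.
# const(result) returns a callback that ignores its arguments and returns result,
# so one helper covers the old noop (None), noop_ctx (nullcontext) and noop_on_log ({}).
from contextlib import nullcontext


def const(result=None):
    def fn(*args, **kwargs):
        return result
    return fn


on_log = const({})                # replaces noop_on_log
on_train = const(nullcontext())   # replaces noop_ctx; nullcontext() is reentrant, so reuse is safe
on_before_optimize = const()      # replaces noop

assert on_log() == {}
assert on_before_optimize(0) is None
with on_train(1):
    pass
```

Note that `const({})` hands back the same dict object on every call, which is fine for the read-only defaults used here.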
```diff
diff --git a/training/lora.py b/training/lora.py
deleted file mode 100644
index 3857d78..0000000
--- a/training/lora.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import torch
-import torch.nn as nn
-
-from diffusers import ModelMixin, ConfigMixin
-from diffusers.configuration_utils import register_to_config
-from diffusers.models.cross_attention import CrossAttention
-from diffusers.utils.import_utils import is_xformers_available
-
-
-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
-
-
-class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(
-                f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
-            )
-
-        self.lora_down = nn.Linear(in_features, rank, bias=False)
-        self.lora_up = nn.Linear(rank, out_features, bias=False)
-        self.scale = 1.0
-
-        nn.init.normal_(self.lora_down.weight, std=1 / rank)
-        nn.init.zeros_(self.lora_up.weight)
-
-    def forward(self, hidden_states):
-        down_hidden_states = self.lora_down(hidden_states)
-        up_hidden_states = self.lora_up(down_hidden_states)
-
-        return up_hidden_states
-
-
-class LoRACrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
-        super().__init__()
-
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
-
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
-        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
-
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-
-class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
-        super().__init__()
-
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
-
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
-        query = attn.head_to_batch_dim(query).contiguous()
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
-        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
-
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
```
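All of the removed processors share one idea: each frozen attention projection gets a scaled low-rank correction `lora_up(lora_down(x))`, with `lora_up` initialized to zero so training starts from the unmodified model. A self-contained sketch of that pattern follows; the 320-dimensional shapes and the `scale` value are illustrative, not taken from this commit.

```python
# Sketch of the low-rank update used by the deleted LoRALinearLayer.
# The 320-dim shapes and scale value below are illustrative assumptions.
import torch
import torch.nn as nn


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
        self.lora_down = nn.Linear(in_features, rank, bias=False)  # project to rank-r bottleneck
        self.lora_up = nn.Linear(rank, out_features, bias=False)   # project back up, zero-initialized
        nn.init.normal_(self.lora_down.weight, std=1 / rank)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, hidden_states):
        return self.lora_up(self.lora_down(hidden_states))


# Usage as in the removed attention processors: frozen projection plus scaled LoRA term.
base = nn.Linear(320, 320)
lora = LoRALinearLayer(320, 320, rank=4)
x = torch.randn(2, 77, 320)
out = base(x) + 1.0 * lora(x)  # with lora_up at zero this initially equals base(x)
```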
```diff
diff --git a/training/util.py b/training/util.py
index 781cf04..a292edd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,12 +1,40 @@
 from pathlib import Path
 import json
 import copy
-import itertools
-from typing import Iterable, Optional
+from typing import Iterable, Union
 from contextlib import contextmanager
 
 import torch
-from PIL import Image
+
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+
+
+class TrainingStrategy():
+    @property
+    def main_model(self) -> torch.nn.Module:
+        ...
+
+    @contextmanager
+    def on_train(self, epoch: int):
+        yield
+
+    @contextmanager
+    def on_eval(self):
+        yield
+
+    def on_before_optimize(self, epoch: int):
+        ...
+
+    def on_after_optimize(self, lr: float):
+        ...
+
+    def on_log():
+        return {}
 
 
 def save_args(basepath: Path, args, extra={}):
```
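The new `TrainingStrategy` base collects the per-method hooks that `train_loop()` previously took as loose callbacks. A hypothetical subclass is sketched below; the class name and the wrapped `text_encoder` attribute are assumptions for illustration, not part of this commit. The remaining hunks of `training/util.py` follow the sketch.

```python
# Hypothetical concrete strategy built on the TrainingStrategy base added above.
# The subclass name and the wrapped text_encoder attribute are illustrative assumptions.
from contextlib import contextmanager

import torch

from training.util import TrainingStrategy


class TextualInversionStrategy(TrainingStrategy):
    def __init__(self, text_encoder: torch.nn.Module):
        self.text_encoder = text_encoder

    @property
    def main_model(self) -> torch.nn.Module:
        return self.text_encoder

    @contextmanager
    def on_train(self, epoch: int):
        self.text_encoder.train()
        yield

    @contextmanager
    def on_eval(self):
        self.text_encoder.eval()
        with torch.no_grad():
            yield
```

`train_loop()` could then call `strategy.on_train(epoch)` and `strategy.on_eval()` in place of the `const()` defaults.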
```diff
@@ -16,12 +44,93 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
+def add_placeholder_tokens(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    placeholder_tokens: list[str],
+    initializer_tokens: list[str],
+    num_vectors: Union[list[int], int]
+):
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_tokens
+    ]
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(placeholder_token_id, initializer_token_id)
+
+    return placeholder_token_ids, initializer_token_ids
 
 
 class AverageMeter:
```
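The `get_models()` and `add_placeholder_tokens()` helpers added above compose along these lines; the model path, token names, and vector count in this sketch are placeholders, not values from the commit. The last hunk of `training/util.py`, which removes the old `CheckpointerBase`, follows.

```python
# Hypothetical usage of the helpers added to training/util.py in this commit.
# The model path, token names and num_vectors value are placeholder inputs.
from training.util import get_models, add_placeholder_tokens

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    "stabilityai/stable-diffusion-2-1"
)

placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
    tokenizer=tokenizer,
    embeddings=embeddings,
    placeholder_tokens=["<token1>"],
    initializer_tokens=["person"],
    num_vectors=1,  # or a per-token list such as [2]
)
```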
```diff
@@ -38,93 +147,6 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-class CheckpointerBase:
-    def __init__(
-        self,
-        train_dataloader,
-        val_dataloader,
-        output_dir: Path,
-        sample_steps: int = 20,
-        sample_guidance_scale: float = 7.5,
-        sample_image_size: int = 768,
-        sample_batches: int = 1,
-        sample_batch_size: int = 1,
-        seed: Optional[int] = None
-    ):
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.sample_steps = sample_steps
-        self.sample_guidance_scale = sample_guidance_scale
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
-        self.seed = seed if seed is not None else torch.random.seed()
-
-    @torch.no_grad()
-    def checkpoint(self, step: int, postfix: str):
-        pass
-
-    @torch.inference_mode()
-    def save_samples(self, pipeline, step: int):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-
-        grid_cols = min(self.sample_batch_size, 4)
-        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
-
-        for pool, data, gen in [
-            ("stable", self.val_dataloader, generator),
-            ("val", self.val_dataloader, None),
-            ("train", self.train_dataloader, None)
-        ]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-            prompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["prompt_ids"]
-            ]
-            nprompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["nprompt_ids"]
-            ]
-
-            for i in range(self.sample_batches):
-                start = i * self.sample_batch_size
-                end = (i + 1) * self.sample_batch_size
-                prompt = prompt_ids[start:end]
-                nprompt = nprompt_ids[start:end]
-
-                samples = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    generator=gen,
-                    guidance_scale=self.sample_guidance_scale,
-                    num_inference_steps=self.sample_steps,
-                    output_type='pil'
-                ).images
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, grid_rows, grid_cols)
-            image_grid.save(file_path, quality=85)
-
-            del all_samples
-            del image_grid
-
-        del generator
-
-
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
```