author     Volpeon <git@volpeon.ink>    2023-01-14 21:53:07 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-14 21:53:07 +0100
commit     83808fe00ac891ad2f625388d144c318b2cb5bfe (patch)
tree       b7ca19d27f90be6f02b14f4a39c62fc7250041a2 /training
parent     TI: Prepare UNet with Accelerate as well (diff)
WIP: Modularization ("free(): invalid pointer" my ass)
Diffstat (limited to 'training')
-rw-r--r--  training/functional.py (renamed from training/common.py)   29
-rw-r--r--  training/lora.py                                           107
-rw-r--r--  training/util.py                                           214
3 files changed, 130 insertions, 220 deletions
diff --git a/training/common.py b/training/functional.py
index 5d1e3f9..2d81eca 100644
--- a/training/common.py
+++ b/training/functional.py
@@ -16,19 +16,14 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import AverageMeter, CheckpointerBase
+from training.util import AverageMeter
+from trainer.base import Checkpointer
 
 
-def noop(*args, **kwards):
-    pass
-
-
-def noop_ctx(*args, **kwards):
-    return nullcontext()
-
-
-def noop_on_log():
-    return {}
+def const(result=None):
+    def fn(*args, **kwargs):
+        return result
+    return fn
 
 
 def generate_class_images(
@@ -210,7 +205,7 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: CheckpointerBase,
+    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -218,11 +213,11 @@ def train_loop(
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
-    on_log: Callable[[], dict[str, Any]] = noop_on_log,
-    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-    on_before_optimize: Callable[[int], None] = noop,
-    on_after_optimize: Callable[[float], None] = noop,
-    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
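
For readers skimming the diff: the rename to training/functional.py also replaces the separate noop helpers with a single const() factory, which now serves as the default for train_loop's callbacks. A minimal standalone sketch of the new behavior, based only on the hunks above (the assertions are illustrative):

from contextlib import nullcontext


def const(result=None):
    # Callback factory: the returned function ignores its arguments and returns `result`.
    def fn(*args, **kwargs):
        return result
    return fn


on_log = const({})                   # replaces noop_on_log
on_train = const(nullcontext())      # replaces noop_ctx
on_before_optimize = const()         # replaces noop

assert on_log() == {}
assert on_before_optimize(0) is None
with on_train(1):                    # arguments are accepted and ignored
    pass
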
diff --git a/training/lora.py b/training/lora.py
deleted file mode 100644
index 3857d78..0000000
--- a/training/lora.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import torch
-import torch.nn as nn
-
-from diffusers import ModelMixin, ConfigMixin
-from diffusers.configuration_utils import register_to_config
-from diffusers.models.cross_attention import CrossAttention
-from diffusers.utils.import_utils import is_xformers_available
-
-
-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
-
-
-class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(
-                f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
-            )
-
-        self.lora_down = nn.Linear(in_features, rank, bias=False)
-        self.lora_up = nn.Linear(rank, out_features, bias=False)
-        self.scale = 1.0
-
-        nn.init.normal_(self.lora_down.weight, std=1 / rank)
-        nn.init.zeros_(self.lora_up.weight)
-
-    def forward(self, hidden_states):
-        down_hidden_states = self.lora_down(hidden_states)
-        up_hidden_states = self.lora_up(down_hidden_states)
-
-        return up_hidden_states
-
-
-class LoRACrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
-        super().__init__()
-
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
-
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
-        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
-
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-
-class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
-        super().__init__()
-
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
-
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
-        query = attn.head_to_batch_dim(query).contiguous()
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
-        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
-
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
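
training/lora.py is deleted by this commit. For reference, the core idea of the removed classes fits in a few lines: each attention projection gets a scaled rank-r correction added to its frozen output. The sketch below is illustrative rather than repository code, and it initializes lora_up randomly so the adapter term is visible, whereas the deleted layer zero-initializes it so training starts exactly at the base model:

import torch
import torch.nn as nn

torch.manual_seed(0)
in_features, out_features, rank, scale = 64, 64, 4, 1.0

to_q = nn.Linear(in_features, out_features)           # stands in for a frozen attn.to_q
lora_down = nn.Linear(in_features, rank, bias=False)
lora_up = nn.Linear(rank, out_features, bias=False)
nn.init.normal_(lora_down.weight, std=1 / rank)
nn.init.normal_(lora_up.weight, std=0.02)             # the deleted code uses zeros here

x = torch.randn(2, 77, in_features)
query = to_q(x) + scale * lora_up(lora_down(x))        # as in LoRACrossAttnProcessor.__call__

# Equivalent merged form: W' = W + scale * (up @ down), so the adapter can be folded
# into the base weight for inference.
w_merged = to_q.weight + scale * lora_up.weight @ lora_down.weight
assert torch.allclose(query, x @ w_merged.T + to_q.bias, atol=1e-5)
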
diff --git a/training/util.py b/training/util.py
index 781cf04..a292edd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,12 +1,40 @@
 from pathlib import Path
 import json
 import copy
-import itertools
-from typing import Iterable, Optional
+from typing import Iterable, Union
 from contextlib import contextmanager
 
 import torch
-from PIL import Image
+
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+
+
+class TrainingStrategy():
+    @property
+    def main_model(self) -> torch.nn.Module:
+        ...
+
+    @contextmanager
+    def on_train(self, epoch: int):
+        yield
+
+    @contextmanager
+    def on_eval(self):
+        yield
+
+    def on_before_optimize(self, epoch: int):
+        ...
+
+    def on_after_optimize(self, lr: float):
+        ...
+
+    def on_log():
+        return {}
 
 
 def save_args(basepath: Path, args, extra={}):
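
The hunk above adds a TrainingStrategy base class whose hooks mirror the on_* callbacks of train_loop in training/functional.py. The subclass below is hypothetical (its name and behavior are not part of this commit) and only illustrates which hooks a concrete strategy would fill in, given the TrainingStrategy defined above:

from contextlib import contextmanager

import torch


class TextualInversionStrategy(TrainingStrategy):
    def __init__(self, text_encoder: torch.nn.Module):
        self.text_encoder = text_encoder

    @property
    def main_model(self) -> torch.nn.Module:
        return self.text_encoder

    @contextmanager
    def on_eval(self):
        # Switch to eval mode for validation, then restore the previous mode.
        was_training = self.text_encoder.training
        self.text_encoder.eval()
        try:
            yield
        finally:
            self.text_encoder.train(was_training)

    def on_log(self):
        return {"strategy": "textual_inversion"}
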
@@ -16,12 +44,93 @@ def save_args(basepath: Path, args, extra={}):
     json.dump(info, f, indent=4)
 
 
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
+def add_placeholder_tokens(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    placeholder_tokens: list[str],
+    initializer_tokens: list[str],
+    num_vectors: Union[list[int], int]
+):
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_tokens
+    ]
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(placeholder_token_id, initializer_token_id)
+
+    return placeholder_token_ids, initializer_token_ids
 
 
 class AverageMeter:
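
The hunk above adds get_models and add_placeholder_tokens to training/util.py. A brief usage sketch of the two helpers follows; the model id, token name and vector count are placeholders, not values from this repository:

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    "runwayml/stable-diffusion-v1-5"
)

placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
    tokenizer=tokenizer,
    embeddings=embeddings,
    placeholder_tokens=["<concept>"],
    initializer_tokens=["person"],
    num_vectors=4,    # Union[list[int], int]: one count for all tokens, or one per token
)
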
@@ -38,93 +147,6 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-class CheckpointerBase:
-    def __init__(
-        self,
-        train_dataloader,
-        val_dataloader,
-        output_dir: Path,
-        sample_steps: int = 20,
-        sample_guidance_scale: float = 7.5,
-        sample_image_size: int = 768,
-        sample_batches: int = 1,
-        sample_batch_size: int = 1,
-        seed: Optional[int] = None
-    ):
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.sample_steps = sample_steps
-        self.sample_guidance_scale = sample_guidance_scale
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
-        self.seed = seed if seed is not None else torch.random.seed()
-
-    @torch.no_grad()
-    def checkpoint(self, step: int, postfix: str):
-        pass
-
-    @torch.inference_mode()
-    def save_samples(self, pipeline, step: int):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-
-        grid_cols = min(self.sample_batch_size, 4)
-        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
-
-        for pool, data, gen in [
-            ("stable", self.val_dataloader, generator),
-            ("val", self.val_dataloader, None),
-            ("train", self.train_dataloader, None)
-        ]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-            prompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["prompt_ids"]
-            ]
-            nprompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["nprompt_ids"]
-            ]
-
-            for i in range(self.sample_batches):
-                start = i * self.sample_batch_size
-                end = (i + 1) * self.sample_batch_size
-                prompt = prompt_ids[start:end]
-                nprompt = nprompt_ids[start:end]
-
-                samples = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    generator=gen,
-                    guidance_scale=self.sample_guidance_scale,
-                    num_inference_steps=self.sample_steps,
-                    output_type='pil'
-                ).images
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, grid_rows, grid_cols)
-            image_grid.save(file_path, quality=85)
-
-            del all_samples
-            del image_grid
-
-        del generator
-
-
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """