Diffstat (limited to 'training')
-rw-r--r-- training/functional.py (renamed from training/common.py) | 29
-rw-r--r-- training/lora.py | 107
-rw-r--r-- training/util.py | 204
3 files changed, 125 insertions, 215 deletions
diff --git a/training/common.py b/training/functional.py
index 5d1e3f9..2d81eca 100644
--- a/training/common.py
+++ b/training/functional.py
@@ -16,19 +16,14 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import AverageMeter, CheckpointerBase
+from training.util import AverageMeter
+from trainer.base import Checkpointer
 
 
-def noop(*args, **kwards):
-    pass
-
-
-def noop_ctx(*args, **kwards):
-    return nullcontext()
-
-
-def noop_on_log():
-    return {}
+def const(result=None):
+    def fn(*args, **kwargs):
+        return result
+    return fn
 
 
 def generate_class_images(
@@ -210,7 +205,7 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: CheckpointerBase,
+    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -218,11 +213,11 @@ def train_loop(
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
-    on_log: Callable[[], dict[str, Any]] = noop_on_log,
-    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-    on_before_optimize: Callable[[int], None] = noop,
-    on_after_optimize: Callable[[float], None] = noop,
-    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
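
Annotation: besides the rename, this hunk replaces the three noop* helpers with a single const() factory whose return value becomes the default for every train_loop callback. A minimal standalone sketch of how those defaults behave (illustrative, not part of the commit):

from contextlib import nullcontext

def const(result=None):
    def fn(*args, **kwargs):
        return result
    return fn

on_log = const({})                # replaces noop_on_log: always returns {}
on_train = const(nullcontext())   # replaces noop_ctx: always returns the same nullcontext instance
on_before_optimize = const()      # replaces noop: returns None regardless of arguments

assert on_log() == {}
assert on_before_optimize(5) is None
with on_train(0):                 # nullcontext() is reusable, so sharing one instance across calls is safe
    pass
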
diff --git a/training/lora.py b/training/lora.py
deleted file mode 100644
index 3857d78..0000000
--- a/training/lora.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import torch
-import torch.nn as nn
-
-from diffusers import ModelMixin, ConfigMixin
-from diffusers.configuration_utils import register_to_config
-from diffusers.models.cross_attention import CrossAttention
-from diffusers.utils.import_utils import is_xformers_available
-
-
-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
-
-
-class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(
-                f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
-            )
-
-        self.lora_down = nn.Linear(in_features, rank, bias=False)
-        self.lora_up = nn.Linear(rank, out_features, bias=False)
-        self.scale = 1.0
-
-        nn.init.normal_(self.lora_down.weight, std=1 / rank)
-        nn.init.zeros_(self.lora_up.weight)
-
-    def forward(self, hidden_states):
-        down_hidden_states = self.lora_down(hidden_states)
-        up_hidden_states = self.lora_up(down_hidden_states)
-
-        return up_hidden_states
-
-
-class LoRACrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
-        super().__init__()
-
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
-
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
-        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
-
-        key = attn.head_to_batch_dim(key)
-        value = attn.head_to_batch_dim(value)
-
-        attention_probs = attn.get_attention_scores(query, key, attention_mask)
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-
-class LoRAXFormersCrossAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4):
-        super().__init__()
-
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
-
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
-        query = attn.head_to_batch_dim(query).contiguous()
-
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-
-        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
-
-        key = attn.head_to_batch_dim(key).contiguous()
-        value = attn.head_to_batch_dim(value).contiguous()
-
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
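
Annotation: the deleted module implemented LoRA adapters for the cross-attention projections. Its core is LoRALinearLayer: a rank-r down-projection followed by an up-projection that is zero-initialized, so the adapter contributes nothing until it is trained. A standalone sketch of that decomposition (illustrative, not the removed class itself):

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    # Same structure as the deleted LoRALinearLayer: up(down(x)), with the up-projection zeroed at init.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.lora_down.weight, std=1 / rank)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x):
        return self.lora_up(self.lora_down(x))

layer = TinyLoRALinear(320, 320, rank=4)
x = torch.randn(2, 77, 320)
assert torch.allclose(layer(x), torch.zeros_like(x))  # zero update at initialization
# Trainable parameters: rank * (in + out) = 4 * 640 = 2560, versus 320 * 320 = 102400 for a full linear layer.
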
diff --git a/training/util.py b/training/util.py
index 781cf04..a292edd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,12 +1,40 @@
 from pathlib import Path
 import json
 import copy
-import itertools
-from typing import Iterable, Optional
+from typing import Iterable, Union
 from contextlib import contextmanager
 
 import torch
-from PIL import Image
+
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+
+
+class TrainingStrategy():
+    @property
+    def main_model(self) -> torch.nn.Module:
+        ...
+
+    @contextmanager
+    def on_train(self, epoch: int):
+        yield
+
+    @contextmanager
+    def on_eval(self):
+        yield
+
+    def on_before_optimize(self, epoch: int):
+        ...
+
+    def on_after_optimize(self, lr: float):
+        ...
+
+    def on_log():
+        return {}
 
 
 def save_args(basepath: Path, args, extra={}):
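
Annotation: the TrainingStrategy stub added above only declares the hook surface used by train_loop (main model, train/eval context managers, optimizer hooks, extra log fields). A hypothetical subclass, purely illustrative (the unet/ema fields and the EMA call are assumptions, not part of this commit), could look like:

from contextlib import contextmanager
import torch
from training.util import TrainingStrategy

class UNetTrainingStrategy(TrainingStrategy):  # hypothetical example, not part of this commit
    def __init__(self, unet: torch.nn.Module, ema=None):
        self.unet = unet
        self.ema = ema  # e.g. an EMA wrapper exposing step(), if one is used

    @property
    def main_model(self) -> torch.nn.Module:
        return self.unet

    @contextmanager
    def on_train(self, epoch: int):
        self.unet.train()
        yield

    @contextmanager
    def on_eval(self):
        self.unet.eval()
        yield

    def on_after_optimize(self, lr: float):
        if self.ema is not None:
            self.ema.step(self.unet.parameters())  # assumed EMA API; adjust to the actual implementation
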
@@ -16,113 +44,107 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
+    if len(missing_data) == 0:
+        return
 
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
 
-    def reset(self):
-        self.sum = self.count = self.avg = 0
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
 
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
 
-class CheckpointerBase:
-    def __init__(
-        self,
-        train_dataloader,
-        val_dataloader,
-        output_dir: Path,
-        sample_steps: int = 20,
-        sample_guidance_scale: float = 7.5,
-        sample_image_size: int = 768,
-        sample_batches: int = 1,
-        sample_batch_size: int = 1,
-        seed: Optional[int] = None
-    ):
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.sample_steps = sample_steps
-        self.sample_guidance_scale = sample_guidance_scale
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
-        self.seed = seed if seed is not None else torch.random.seed()
+            for i, image in enumerate(images):
+                image.save(image_name[i])
 
-    @torch.no_grad()
-    def checkpoint(self, step: int, postfix: str):
-        pass
+    del pipeline
 
-    @torch.inference_mode()
-    def save_samples(self, pipeline, step: int):
-        samples_path = Path(self.output_dir).joinpath("samples")
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
 
-        grid_cols = min(self.sample_batch_size, 4)
-        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
 
-        for pool, data, gen in [
-            ("stable", self.val_dataloader, generator),
-            ("val", self.val_dataloader, None),
-            ("train", self.train_dataloader, None)
-        ]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
+    embeddings = patch_managed_embeddings(text_encoder)
 
-            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-            prompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["prompt_ids"]
-            ]
-            nprompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["nprompt_ids"]
-            ]
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
-            for i in range(self.sample_batches):
-                start = i * self.sample_batch_size
-                end = (i + 1) * self.sample_batch_size
-                prompt = prompt_ids[start:end]
-                nprompt = nprompt_ids[start:end]
 
-                samples = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    generator=gen,
-                    guidance_scale=self.sample_guidance_scale,
-                    num_inference_steps=self.sample_steps,
-                    output_type='pil'
-                ).images
+def add_placeholder_tokens(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    placeholder_tokens: list[str],
+    initializer_tokens: list[str],
+    num_vectors: Union[list[int], int]
+):
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_tokens
+    ]
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
 
-                all_samples += samples
+    embeddings.resize(len(tokenizer))
 
-                del samples
+    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-                image_grid = make_grid(all_samples, grid_rows, grid_cols)
-                image_grid.save(file_path, quality=85)
+    return placeholder_token_ids, initializer_token_ids
 
-                del all_samples
-                del image_grid
 
-            del generator
+class AverageMeter:
+    def __init__(self, name=None):
+        self.name = name
+        self.reset()
+
+    def reset(self):
+        self.sum = self.count = self.avg = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
 
 
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
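
Annotation: with these helpers consolidated in training/util.py, a typical setup sequence would look roughly like the sketch below; the model path and token names are placeholders, not values taken from this commit.

from training.util import AverageMeter, add_placeholder_tokens, get_models

# Placeholder model path, for illustration only.
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    "path/to/pretrained-model"
)

# Register a new placeholder token and seed its embedding from an existing token.
placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
    tokenizer=tokenizer,
    embeddings=embeddings,
    placeholder_tokens=["<new-concept>"],
    initializer_tokens=["person"],
    num_vectors=1,
)

# AverageMeter keeps a running mean, e.g. of the training loss.
avg_loss = AverageMeter("loss")
for loss in (0.9, 0.7, 0.8):
    avg_loss.update(loss)
print(f"avg loss: {avg_loss.avg:.3f}")  # (0.9 + 0.7 + 0.8) / 3 = 0.8
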
