diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 365 |
1 files changed, 365 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py new file mode 100644 index 0000000..2d81eca --- /dev/null +++ b/training/functional.py | |||
| @@ -0,0 +1,365 @@ | |||
| 1 | import math | ||
| 2 | from contextlib import _GeneratorContextManager, nullcontext | ||
| 3 | from typing import Callable, Any, Tuple, Union | ||
| 4 | |||
| 5 | import torch | ||
| 6 | import torch.nn.functional as F | ||
| 7 | from torch.utils.data import DataLoader | ||
| 8 | |||
| 9 | from accelerate import Accelerator | ||
| 10 | from transformers import CLIPTextModel | ||
| 11 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler | ||
| 12 | |||
| 13 | from tqdm.auto import tqdm | ||
| 14 | |||
| 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 16 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | ||
| 17 | from models.clip.util import get_extended_embeddings | ||
| 18 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
| 19 | from training.util import AverageMeter | ||
| 20 | from trainer.base import Checkpointer | ||
| 21 | |||
| 22 | |||
| 23 | def const(result=None): | ||
| 24 | def fn(*args, **kwargs): | ||
| 25 | return result | ||
| 26 | return fn | ||
| 27 | |||
| 28 | |||
| 29 | def generate_class_images( | ||
| 30 | accelerator, | ||
| 31 | text_encoder, | ||
| 32 | vae, | ||
| 33 | unet, | ||
| 34 | tokenizer, | ||
| 35 | scheduler, | ||
| 36 | data_train, | ||
| 37 | sample_batch_size, | ||
| 38 | sample_image_size, | ||
| 39 | sample_steps | ||
| 40 | ): | ||
| 41 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | ||
| 42 | |||
| 43 | if len(missing_data) == 0: | ||
| 44 | return | ||
| 45 | |||
| 46 | batched_data = [ | ||
| 47 | missing_data[i:i+sample_batch_size] | ||
| 48 | for i in range(0, len(missing_data), sample_batch_size) | ||
| 49 | ] | ||
| 50 | |||
| 51 | pipeline = VlpnStableDiffusion( | ||
| 52 | text_encoder=text_encoder, | ||
| 53 | vae=vae, | ||
| 54 | unet=unet, | ||
| 55 | tokenizer=tokenizer, | ||
| 56 | scheduler=scheduler, | ||
| 57 | ).to(accelerator.device) | ||
| 58 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
| 59 | |||
| 60 | with torch.inference_mode(): | ||
| 61 | for batch in batched_data: | ||
| 62 | image_name = [item.class_image_path for item in batch] | ||
| 63 | prompt = [item.cprompt for item in batch] | ||
| 64 | nprompt = [item.nprompt for item in batch] | ||
| 65 | |||
| 66 | images = pipeline( | ||
| 67 | prompt=prompt, | ||
| 68 | negative_prompt=nprompt, | ||
| 69 | height=sample_image_size, | ||
| 70 | width=sample_image_size, | ||
| 71 | num_inference_steps=sample_steps | ||
| 72 | ).images | ||
| 73 | |||
| 74 | for i, image in enumerate(images): | ||
| 75 | image.save(image_name[i]) | ||
| 76 | |||
| 77 | del pipeline | ||
| 78 | |||
| 79 | if torch.cuda.is_available(): | ||
| 80 | torch.cuda.empty_cache() | ||
| 81 | |||
| 82 | |||
| 83 | def get_models(pretrained_model_name_or_path: str): | ||
| 84 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | ||
| 85 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | ||
| 86 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | ||
| 87 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | ||
| 88 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | ||
| 89 | sample_scheduler = DPMSolverMultistepScheduler.from_pretrained( | ||
| 90 | pretrained_model_name_or_path, subfolder='scheduler') | ||
| 91 | |||
| 92 | vae.enable_slicing() | ||
| 93 | vae.set_use_memory_efficient_attention_xformers(True) | ||
| 94 | unet.set_use_memory_efficient_attention_xformers(True) | ||
| 95 | |||
| 96 | embeddings = patch_managed_embeddings(text_encoder) | ||
| 97 | |||
| 98 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings | ||
| 99 | |||
| 100 | |||
| 101 | def add_placeholder_tokens( | ||
| 102 | tokenizer: MultiCLIPTokenizer, | ||
| 103 | embeddings: ManagedCLIPTextEmbeddings, | ||
| 104 | placeholder_tokens: list[str], | ||
| 105 | initializer_tokens: list[str], | ||
| 106 | num_vectors: Union[list[int], int] | ||
| 107 | ): | ||
| 108 | initializer_token_ids = [ | ||
| 109 | tokenizer.encode(token, add_special_tokens=False) | ||
| 110 | for token in initializer_tokens | ||
| 111 | ] | ||
| 112 | placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors) | ||
| 113 | |||
| 114 | embeddings.resize(len(tokenizer)) | ||
| 115 | |||
| 116 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | ||
| 117 | embeddings.add_embed(placeholder_token_id, initializer_token_id) | ||
| 118 | |||
| 119 | return placeholder_token_ids, initializer_token_ids | ||
| 120 | |||
| 121 | |||
| 122 | def loss_step( | ||
| 123 | vae: AutoencoderKL, | ||
| 124 | noise_scheduler: DDPMScheduler, | ||
| 125 | unet: UNet2DConditionModel, | ||
| 126 | text_encoder: CLIPTextModel, | ||
| 127 | prior_loss_weight: float, | ||
| 128 | seed: int, | ||
| 129 | step: int, | ||
| 130 | batch: dict[str, Any], | ||
| 131 | eval: bool = False | ||
| 132 | ): | ||
| 133 | # Convert images to latent space | ||
| 134 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | ||
| 135 | latents = latents * 0.18215 | ||
| 136 | |||
| 137 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | ||
| 138 | |||
| 139 | # Sample noise that we'll add to the latents | ||
| 140 | noise = torch.randn( | ||
| 141 | latents.shape, | ||
| 142 | dtype=latents.dtype, | ||
| 143 | layout=latents.layout, | ||
| 144 | device=latents.device, | ||
| 145 | generator=generator | ||
| 146 | ) | ||
| 147 | bsz = latents.shape[0] | ||
| 148 | # Sample a random timestep for each image | ||
| 149 | timesteps = torch.randint( | ||
| 150 | 0, | ||
| 151 | noise_scheduler.config.num_train_timesteps, | ||
| 152 | (bsz,), | ||
| 153 | generator=generator, | ||
| 154 | device=latents.device, | ||
| 155 | ) | ||
| 156 | timesteps = timesteps.long() | ||
| 157 | |||
| 158 | # Add noise to the latents according to the noise magnitude at each timestep | ||
| 159 | # (this is the forward diffusion process) | ||
| 160 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
| 161 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
| 162 | |||
| 163 | # Get the text embedding for conditioning | ||
| 164 | encoder_hidden_states = get_extended_embeddings( | ||
| 165 | text_encoder, | ||
| 166 | batch["input_ids"], | ||
| 167 | batch["attention_mask"] | ||
| 168 | ) | ||
| 169 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | ||
| 170 | |||
| 171 | # Predict the noise residual | ||
| 172 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
| 173 | |||
| 174 | # Get the target for loss depending on the prediction type | ||
| 175 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 176 | target = noise | ||
| 177 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 178 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 179 | else: | ||
| 180 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 181 | |||
| 182 | if batch["with_prior"].all(): | ||
| 183 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
| 184 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
| 185 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
| 186 | |||
| 187 | # Compute instance loss | ||
| 188 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
| 189 | |||
| 190 | # Compute prior loss | ||
| 191 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
| 192 | |||
| 193 | # Add the prior loss to the instance loss. | ||
| 194 | loss = loss + prior_loss_weight * prior_loss | ||
| 195 | else: | ||
| 196 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
| 197 | |||
| 198 | acc = (model_pred == target).float().mean() | ||
| 199 | |||
| 200 | return loss, acc, bsz | ||
| 201 | |||
| 202 | |||
| 203 | def train_loop( | ||
| 204 | accelerator: Accelerator, | ||
| 205 | optimizer: torch.optim.Optimizer, | ||
| 206 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
| 207 | model: torch.nn.Module, | ||
| 208 | checkpointer: Checkpointer, | ||
| 209 | train_dataloader: DataLoader, | ||
| 210 | val_dataloader: DataLoader, | ||
| 211 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | ||
| 212 | sample_frequency: int = 10, | ||
| 213 | checkpoint_frequency: int = 50, | ||
| 214 | global_step_offset: int = 0, | ||
| 215 | num_epochs: int = 100, | ||
| 216 | on_log: Callable[[], dict[str, Any]] = const({}), | ||
| 217 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()), | ||
| 218 | on_before_optimize: Callable[[int], None] = const(), | ||
| 219 | on_after_optimize: Callable[[float], None] = const(), | ||
| 220 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | ||
| 221 | ): | ||
| 222 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | ||
| 223 | num_val_steps_per_epoch = len(val_dataloader) | ||
| 224 | |||
| 225 | num_training_steps = num_training_steps_per_epoch * num_epochs | ||
| 226 | num_val_steps = num_val_steps_per_epoch * num_epochs | ||
| 227 | |||
| 228 | global_step = 0 | ||
| 229 | |||
| 230 | avg_loss = AverageMeter() | ||
| 231 | avg_acc = AverageMeter() | ||
| 232 | |||
| 233 | avg_loss_val = AverageMeter() | ||
| 234 | avg_acc_val = AverageMeter() | ||
| 235 | |||
| 236 | max_acc_val = 0.0 | ||
| 237 | |||
| 238 | local_progress_bar = tqdm( | ||
| 239 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | ||
| 240 | disable=not accelerator.is_local_main_process, | ||
| 241 | dynamic_ncols=True | ||
| 242 | ) | ||
| 243 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | ||
| 244 | |||
| 245 | global_progress_bar = tqdm( | ||
| 246 | range(num_training_steps + num_val_steps), | ||
| 247 | disable=not accelerator.is_local_main_process, | ||
| 248 | dynamic_ncols=True | ||
| 249 | ) | ||
| 250 | global_progress_bar.set_description("Total progress") | ||
| 251 | |||
| 252 | try: | ||
| 253 | for epoch in range(num_epochs): | ||
| 254 | if accelerator.is_main_process: | ||
| 255 | if epoch % sample_frequency == 0: | ||
| 256 | checkpointer.save_samples(global_step + global_step_offset) | ||
| 257 | |||
| 258 | if epoch % checkpoint_frequency == 0 and epoch != 0: | ||
| 259 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
| 260 | |||
| 261 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | ||
| 262 | local_progress_bar.reset() | ||
| 263 | |||
| 264 | model.train() | ||
| 265 | |||
| 266 | with on_train(epoch): | ||
| 267 | for step, batch in enumerate(train_dataloader): | ||
| 268 | with accelerator.accumulate(model): | ||
| 269 | loss, acc, bsz = loss_step(step, batch) | ||
| 270 | |||
| 271 | accelerator.backward(loss) | ||
| 272 | |||
| 273 | on_before_optimize(epoch) | ||
| 274 | |||
| 275 | optimizer.step() | ||
| 276 | lr_scheduler.step() | ||
| 277 | optimizer.zero_grad(set_to_none=True) | ||
| 278 | |||
| 279 | avg_loss.update(loss.detach_(), bsz) | ||
| 280 | avg_acc.update(acc.detach_(), bsz) | ||
| 281 | |||
| 282 | # Checks if the accelerator has performed an optimization step behind the scenes | ||
| 283 | if accelerator.sync_gradients: | ||
| 284 | on_after_optimize(lr_scheduler.get_last_lr()[0]) | ||
| 285 | |||
| 286 | local_progress_bar.update(1) | ||
| 287 | global_progress_bar.update(1) | ||
| 288 | |||
| 289 | global_step += 1 | ||
| 290 | |||
| 291 | logs = { | ||
| 292 | "train/loss": avg_loss.avg.item(), | ||
| 293 | "train/acc": avg_acc.avg.item(), | ||
| 294 | "train/cur_loss": loss.item(), | ||
| 295 | "train/cur_acc": acc.item(), | ||
| 296 | "lr": lr_scheduler.get_last_lr()[0], | ||
| 297 | } | ||
| 298 | logs.update(on_log()) | ||
| 299 | |||
| 300 | accelerator.log(logs, step=global_step) | ||
| 301 | |||
| 302 | local_progress_bar.set_postfix(**logs) | ||
| 303 | |||
| 304 | if global_step >= num_training_steps: | ||
| 305 | break | ||
| 306 | |||
| 307 | accelerator.wait_for_everyone() | ||
| 308 | |||
| 309 | model.eval() | ||
| 310 | |||
| 311 | cur_loss_val = AverageMeter() | ||
| 312 | cur_acc_val = AverageMeter() | ||
| 313 | |||
| 314 | with torch.inference_mode(), on_eval(): | ||
| 315 | for step, batch in enumerate(val_dataloader): | ||
| 316 | loss, acc, bsz = loss_step(step, batch, True) | ||
| 317 | |||
| 318 | loss = loss.detach_() | ||
| 319 | acc = acc.detach_() | ||
| 320 | |||
| 321 | cur_loss_val.update(loss, bsz) | ||
| 322 | cur_acc_val.update(acc, bsz) | ||
| 323 | |||
| 324 | avg_loss_val.update(loss, bsz) | ||
| 325 | avg_acc_val.update(acc, bsz) | ||
| 326 | |||
| 327 | local_progress_bar.update(1) | ||
| 328 | global_progress_bar.update(1) | ||
| 329 | |||
| 330 | logs = { | ||
| 331 | "val/loss": avg_loss_val.avg.item(), | ||
| 332 | "val/acc": avg_acc_val.avg.item(), | ||
| 333 | "val/cur_loss": loss.item(), | ||
| 334 | "val/cur_acc": acc.item(), | ||
| 335 | } | ||
| 336 | local_progress_bar.set_postfix(**logs) | ||
| 337 | |||
| 338 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
| 339 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
| 340 | |||
| 341 | accelerator.log(logs, step=global_step) | ||
| 342 | |||
| 343 | local_progress_bar.clear() | ||
| 344 | global_progress_bar.clear() | ||
| 345 | |||
| 346 | if accelerator.is_main_process: | ||
| 347 | if avg_acc_val.avg.item() > max_acc_val: | ||
| 348 | accelerator.print( | ||
| 349 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
| 350 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | ||
| 351 | max_acc_val = avg_acc_val.avg.item() | ||
| 352 | |||
| 353 | # Create the pipeline using using the trained modules and save it. | ||
| 354 | if accelerator.is_main_process: | ||
| 355 | print("Finished!") | ||
| 356 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
| 357 | checkpointer.save_samples(global_step + global_step_offset) | ||
| 358 | accelerator.end_training() | ||
| 359 | |||
| 360 | except KeyboardInterrupt: | ||
| 361 | if accelerator.is_main_process: | ||
| 362 | print("Interrupted") | ||
| 363 | checkpointer.checkpoint(global_step + global_step_offset, "end") | ||
| 364 | accelerator.end_training() | ||
| 365 | quit() | ||
