 environment.yaml          |  2 ++
 train_lora.py             | 36 ++-----
 train_ti.py               |  4 +-
 training/functional.py    | 35 +++++--
 training/strategy/lora.py | 70 ++++++++++++--
 5 files changed, 97 insertions(+), 50 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 9355f37..db43bd5 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -17,9 +17,11 @@ dependencies:
   - -e git+https://github.com/huggingface/diffusers#egg=diffusers
   - accelerate==0.17.1
   - bitsandbytes==0.37.1
+  - peft==0.2.0
   - python-slugify>=6.1.2
   - safetensors==0.3.0
   - setuptools==65.6.3
   - test-tube>=0.7.5
   - transformers==4.27.1
   - triton==2.0.0
+  - xformers==0.0.17.dev480
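The commit pins `peft==0.2.0` for the LoRA rework below and adds a development build of `xformers`. A quick sanity check that the environment resolved to the expected pins (a sketch, not part of the repository):

```python
# Confirm the pinned packages import at the expected versions.
import peft
import transformers
import xformers

assert peft.__version__ == "0.2.0"
assert transformers.__version__ == "4.27.1"
print(xformers.__version__)  # expected: 0.0.17.dev480
```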
diff --git a/train_lora.py b/train_lora.py
index e65e7be..2a798f3 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -12,8 +12,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -426,34 +424,16 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.lora_rank
-        )
-
-    unet.set_attn_processor(lora_attn_procs)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -505,7 +485,6 @@ def main():
         unet=unet,
         text_encoder=text_encoder,
         vae=vae,
-        lora_layers=lora_layers,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
@@ -540,7 +519,10 @@ def main():
     datamodule.setup()
 
     optimizer = create_optimizer(
-        lora_layers.parameters(),
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
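With PEFT injecting the adapters (see training/strategy/lora.py below), the hand-rolled `LoRACrossAttnProcessor` setup is gone and the optimizer now receives every parameter of both models; `mark_only_lora_as_trainable` later freezes everything except the LoRA matrices, so the frozen weights simply never receive gradients. Because the PEFT wrapping happens inside `lora_prepare`, after the optimizer is built, the `requires_grad` flags are not yet narrowed at this point; if the wrapping were hoisted ahead of optimizer creation, a tighter variant (a sketch, not what the commit does) could filter:

```python
import itertools

# Hand the optimizer only parameters that remain trainable after the LoRA freeze.
trainable_params = itertools.chain(
    (p for p in unet.parameters() if p.requires_grad),
    (p for p in text_encoder.parameters() if p.requires_grad),
)
optimizer = create_optimizer(trainable_params, lr=args.learning_rate)
```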
diff --git a/train_ti.py b/train_ti.py
index fd23517..2e92ae4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -547,8 +547,8 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    # vae.set_use_memory_efficient_attention_xformers(True)
-    # unet.enable_xformers_memory_efficient_attention()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
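Re-enabling the xformers paths ties train_ti.py to the pinned dev build from environment.yaml. If portability matters, the calls can be guarded; a minimal sketch (the commit enables them unconditionally):

```python
# Fall back to the default attention implementation when xformers is absent.
try:
    import xformers  # noqa: F401
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.enable_xformers_memory_efficient_attention()
except ImportError:
    pass
```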
diff --git a/training/functional.py b/training/functional.py
index 8dc2b9f..43ee356 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -251,6 +251,25 @@ def add_placeholder_tokens(
     return placeholder_token_ids, initializer_token_ids
 
 
+def snr_weight(noisy_latents, latents, gamma):
+    if gamma:
+        sigma = torch.sub(noisy_latents, latents)
+        zeros = torch.zeros_like(sigma)
+        alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        snr = torch.div(alpha_mean_sq, sigma_mean_sq)
+        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
+        return snr_weight
+
+    return torch.tensor(
+        [1],
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: SchedulerMixin,
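The new `snr_weight` estimates the signal-to-noise ratio empirically per example as mean(latents²) / mean((noisy_latents − latents)²) over the non-batch dimensions, then returns the Min-SNR weight min(γ/SNR, 1); with `gamma` unset it collapses to a constant 1. A rough shape check (a sketch using this commit's function, not a test from the repository):

```python
import torch

latents = torch.randn(4, 4, 64, 64)
noisy_latents = latents + 0.5 * torch.randn_like(latents)

w = snr_weight(noisy_latents, latents, gamma=5.0)
print(w.shape)  # torch.Size([4]): one scalar weight per batch element
```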
@@ -308,21 +327,13 @@ def loss_step(
     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
     # Get the target for loss depending on the prediction type
-    alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
-    snr = alpha_t / (1 - alpha_t)
-    min_snr = snr.clamp(max=min_snr_gamma)
-
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
-        loss_weight = min_snr / snr
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        loss_weight = min_snr / (snr + 1)
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    loss_weight = loss_weight[..., None, None, None]
-
     if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -339,7 +350,11 @@ def loss_step(
     else:
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
-    loss = (loss_weight * loss).mean([1, 2, 3]).mean()
+    loss = loss.mean([1, 2, 3])
+
+    loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma)
+    loss = (loss_weight * loss).mean()
+
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
@@ -412,7 +427,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0 and epoch != 0:
+                if epoch % sample_frequency == 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
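The deleted branch computed the same Min-SNR-γ weighting (Hang et al., 2023, "Efficient Diffusion Training via Min-SNR Weighting Strategy") in closed form from the scheduler's noise schedule, including the (SNR + 1) denominator for v-prediction targets; the replacement applies min(γ/SNR, 1) regardless of prediction type. For reference, the closed-form version, reassembled from the deleted lines:

```python
# SNR(t) = alpha_bar_t / (1 - alpha_bar_t), straight from the scheduler.
alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
snr = alpha_t / (1 - alpha_t)
min_snr = snr.clamp(max=min_snr_gamma)

if noise_scheduler.config.prediction_type == "epsilon":
    loss_weight = min_snr / snr        # == min(gamma / SNR, 1)
elif noise_scheduler.config.prediction_type == "v_prediction":
    loss_weight = min_snr / (snr + 1)  # Min-SNR weight for v-prediction
```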
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cab5e4c..aa75bec 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -2,6 +2,7 @@ from typing import Optional
 from functools import partial
 from contextlib import contextmanager
 from pathlib import Path
+import itertools
 
 import torch
 from torch.utils.data import DataLoader
@@ -9,12 +10,18 @@ from torch.utils.data import DataLoader
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
-from diffusers.loaders import AttnProcsLayers
+from peft import LoraConfig, LoraModel, get_peft_model_state_dict
+from peft.tuners.lora import mark_only_lora_as_trainable
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
+# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
+UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+
 def lora_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
@@ -27,7 +34,6 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    lora_layers: AttnProcsLayers,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -57,7 +63,8 @@
     )
 
     def on_prepare():
-        lora_layers.requires_grad_(True)
+        mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
+        mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
 
     def on_accum_model():
         return unet
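`mark_only_lora_as_trainable` freezes every weight except the injected LoRA matrices (plus biases, depending on the `bias` setting), which is what makes it safe to hand the optimizer and gradient clipping full parameter lists. A quick way to confirm the freeze took effect (a sketch, assuming it runs after `lora_prepare` has wrapped the models):

```python
# Count trainable vs. total parameters on the wrapped UNet.
trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f"UNet trainable: {trainable:,} / {total:,} ({100 * trainable / total:.3f}%)")
```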
@@ -73,24 +80,44 @@
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(
+            itertools.chain(unet.parameters(), text_encoder.parameters()),
+            max_grad_norm
+        )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(
-            checkpoint_output_dir / f"{step}_{postfix}",
-            safe_serialization=True
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
+
+        lora_config = {}
+        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
+        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+
+        text_encoder_state_dict = get_peft_model_state_dict(
+            text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
         )
+        text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+        state_dict.update(text_encoder_state_dict)
+        lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+
+        accelerator.print(state_dict)
+        accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
+
         del unet_
+        del text_encoder_
 
     @torch.no_grad()
     def on_sample(step):
         unet_ = accelerator.unwrap_model(unet, False)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
+
         save_samples_(step=step, unet=unet_)
+
         del unet_
+        del text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
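The checkpoint is now a single flat `.pt` state dict: UNet LoRA keys stored as-is and text-encoder keys carrying a `text_encoder_` prefix. Note that the assembled `lora_config` dict is never written to disk, so a loader must rebuild the `LoraConfig` objects from the training arguments. A hypothetical loader for this format (`checkpoint_path` is illustrative):

```python
import torch
from peft import set_peft_model_state_dict

state_dict = torch.load(checkpoint_path, map_location="cpu")
unet_sd = {k: v for k, v in state_dict.items() if not k.startswith("text_encoder_")}
text_sd = {k[len("text_encoder_"):]: v for k, v in state_dict.items() if k.startswith("text_encoder_")}

# Both models must already be wrapped in LoraModel with matching configs.
set_peft_model_state_dict(unet, unet_sd)
set_peft_model_state_dict(text_encoder, text_sd)
```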
@@ -114,13 +141,34 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    lora_layers: AttnProcsLayers,
+    lora_rank: int = 4,
+    lora_alpha: int = 32,
+    lora_dropout: float = 0,
+    lora_bias: str = "none",
     **kwargs
 ):
-    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    unet_config = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_alpha,
+        target_modules=UNET_TARGET_MODULES,
+        lora_dropout=lora_dropout,
+        bias=lora_bias,
+    )
+    unet = LoraModel(unet_config, unet)
+
+    text_encoder_config = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_alpha,
+        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        lora_dropout=lora_dropout,
+        bias=lora_bias,
+    )
+    text_encoder = LoraModel(text_encoder_config, text_encoder)
+
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 lora_strategy = TrainingStrategy(
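With the defaults above (r=4, lora_alpha=32), PEFT scales the low-rank update by alpha / r = 8, so the effective weight is W + 8·BA. A minimal numeric illustration of the standard LoRA update these parameters control (an illustration of the usual semantics, not code from peft):

```python
import torch

d, r, alpha = 768, 4, 32
A = torch.randn(r, d) * 0.01  # "lora_A": small random init
B = torch.zeros(d, r)         # "lora_B": zero init, so delta_W starts at zero
delta_W = (alpha / r) * (B @ A)
print(delta_W.abs().max().item())  # 0.0 at init; grows as B trains
```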
