 environment.yaml          |  2
 train_lora.py             | 36
 train_ti.py               |  4
 training/functional.py    | 35
 training/strategy/lora.py | 70
 5 files changed, 97 insertions(+), 50 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 9355f37..db43bd5 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -17,9 +17,11 @@ dependencies:
 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
 - accelerate==0.17.1
 - bitsandbytes==0.37.1
+- peft==0.2.0
 - python-slugify>=6.1.2
 - safetensors==0.3.0
 - setuptools==65.6.3
 - test-tube>=0.7.5
 - transformers==4.27.1
 - triton==2.0.0
+- xformers==0.0.17.dev480
diff --git a/train_lora.py b/train_lora.py
index e65e7be..2a798f3 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -12,8 +12,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -426,34 +424,16 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-
-        lora_attn_procs[name] = LoRACrossAttnProcessor(
-            hidden_size=hidden_size,
-            cross_attention_dim=cross_attention_dim,
-            rank=args.lora_rank
-        )
-
-    unet.set_attn_processor(lora_attn_procs)
+    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
+    tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.enable_xformers_memory_efficient_attention()
 
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -505,7 +485,6 @@ def main():
         unet=unet,
         text_encoder=text_encoder,
         vae=vae,
-        lora_layers=lora_layers,
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         with_prior_preservation=args.num_class_images != 0,
@@ -540,7 +519,10 @@ def main():
     datamodule.setup()
 
     optimizer = create_optimizer(
-        lora_layers.parameters(),
+        itertools.chain(
+            unet.parameters(),
+            text_encoder.parameters(),
+        ),
         lr=args.learning_rate,
     )
 
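Note on the optimizer change above: with LoRA adapters now injected into both the UNet and the text encoder (see training/strategy/lora.py below), the optimizer is handed both parameter sets. A hedged alternative, not part of this commit, would be separate param groups so the text encoder can train at a lower learning rate, a common choice in DreamBooth-style LoRA setups; the sketch below uses torch.optim.AdamW directly and purely illustrative values.

```python
# Hedged alternative sketch (not in this commit): separate param groups give
# the text encoder a smaller learning rate than the UNet.
import torch

optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters(), "lr": 1e-4},          # illustrative values
        {"params": text_encoder.parameters(), "lr": 5e-5},
    ],
    weight_decay=1e-2,
)
```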
diff --git a/train_ti.py b/train_ti.py
index fd23517..2e92ae4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -547,8 +547,8 @@ def main():
     tokenizer.set_dropout(args.vector_dropout)
 
     vae.enable_slicing()
-    # vae.set_use_memory_efficient_attention_xformers(True)
-    # unet.enable_xformers_memory_efficient_attention()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
     # unet = torch.compile(unet)
 
     if args.gradient_checkpointing:
diff --git a/training/functional.py b/training/functional.py
index 8dc2b9f..43ee356 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -251,6 +251,25 @@ def add_placeholder_tokens(
     return placeholder_token_ids, initializer_token_ids
 
 
+def snr_weight(noisy_latents, latents, gamma):
+    if gamma:
+        sigma = torch.sub(noisy_latents, latents)
+        zeros = torch.zeros_like(sigma)
+        alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        snr = torch.div(alpha_mean_sq, sigma_mean_sq)
+        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
+        return snr_weight
+
+    return torch.tensor(
+        [1],
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: SchedulerMixin,
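The snr_weight helper added above replaces the inline min-SNR weighting that the next hunk removes from loss_step: it estimates a per-sample signal-to-noise ratio from the difference between the noisy and clean latents and returns min(gamma / SNR, 1), falling back to a weight of 1 when gamma is unset. A rough usage sketch, assuming the function is imported from training.functional as added by this commit; the tensor shapes and noise level are illustrative only.

```python
# Rough usage sketch for snr_weight: a small perturbation means high SNR, so the
# weight approaches gamma / SNR; heavily noised samples keep a weight of 1.
import torch
from training.functional import snr_weight

latents = torch.randn(2, 4, 8, 8)
noisy_latents = latents + 0.1 * torch.randn_like(latents)  # small perturbation -> high SNR
weights = snr_weight(noisy_latents, latents, gamma=5.0)
print(weights)  # roughly 5 / SNR per sample, capped at 1.0
```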
@@ -308,21 +327,13 @@ def loss_step(
     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
     # Get the target for loss depending on the prediction type
-    alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
-    snr = alpha_t / (1 - alpha_t)
-    min_snr = snr.clamp(max=min_snr_gamma)
-
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
-        loss_weight = min_snr / snr
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        loss_weight = min_snr / (snr + 1)
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    loss_weight = loss_weight[..., None, None, None]
-
     if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -339,7 +350,11 @@ def loss_step(
     else:
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
-    loss = (loss_weight * loss).mean([1, 2, 3]).mean()
+    loss = loss.mean([1, 2, 3])
+
+    loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma)
+    loss = (loss_weight * loss).mean()
+
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
@@ -412,7 +427,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0 and epoch != 0:
+                if epoch % sample_frequency == 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cab5e4c..aa75bec 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -2,6 +2,7 @@ from typing import Optional
 from functools import partial
 from contextlib import contextmanager
 from pathlib import Path
+import itertools
 
 import torch
 from torch.utils.data import DataLoader
@@ -9,12 +10,18 @@ from torch.utils.data import DataLoader
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
-from diffusers.loaders import AttnProcsLayers
+from peft import LoraConfig, LoraModel, get_peft_model_state_dict
+from peft.tuners.lora import mark_only_lora_as_trainable
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
+# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
+UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+
 def lora_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
@@ -27,7 +34,6 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    lora_layers: AttnProcsLayers,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -57,7 +63,8 @@ def lora_strategy_callbacks(
     )
 
     def on_prepare():
-        lora_layers.requires_grad_(True)
+        mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
+        mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
 
     def on_accum_model():
         return unet
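mark_only_lora_as_trainable freezes every parameter except the injected LoRA weights (and, depending on the configured bias mode, bias terms). A small sanity check one could run after on_prepare() is sketched below; count_trainable is a hypothetical helper, not part of this commit.

```python
# Hypothetical helper to verify that only the LoRA parameters remain trainable
# after mark_only_lora_as_trainable has run.
def count_trainable(module):
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total

# e.g. inside on_prepare():
#   print("unet:", count_trainable(unet), "text_encoder:", count_trainable(text_encoder))
```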
@@ -73,24 +80,44 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(
+            itertools.chain(unet.parameters(), text_encoder.parameters()),
+            max_grad_norm
+        )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(
-            checkpoint_output_dir / f"{step}_{postfix}",
-            safe_serialization=True
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
+
+        lora_config = {}
+        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
+        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+
+        text_encoder_state_dict = get_peft_model_state_dict(
+            text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
         )
+        text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+        state_dict.update(text_encoder_state_dict)
+        lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+
+        accelerator.print(state_dict)
+        accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
+
         del unet_
+        del text_encoder_
 
     @torch.no_grad()
     def on_sample(step):
         unet_ = accelerator.unwrap_model(unet, False)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
+
         save_samples_(step=step, unet=unet_)
+
         del unet_
+        del text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
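on_checkpoint now writes a single .pt file containing the UNet LoRA state dict plus the text-encoder LoRA weights under a text_encoder_ prefix (note that the lora_config dict is assembled but not written anywhere in this hunk). Below is a hedged sketch of how such a checkpoint could be loaded back, assuming peft exposes set_peft_model_state_dict as its LoRA DreamBooth example does; load_lora_checkpoint is a hypothetical helper, not part of this commit.

```python
# Hypothetical loader for the checkpoint format written above (not part of this
# commit).  Assumes peft's set_peft_model_state_dict helper is available.
import torch
from peft import set_peft_model_state_dict

def load_lora_checkpoint(path, unet, text_encoder):
    state_dict = torch.load(path, map_location="cpu")
    prefix = "text_encoder_"
    unet_sd = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
    text_sd = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    set_peft_model_state_dict(unet, unet_sd)          # unet is the LoraModel-wrapped UNet
    set_peft_model_state_dict(text_encoder, text_sd)  # likewise for the text encoder
```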
@@ -114,13 +141,34 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    lora_layers: AttnProcsLayers,
+    lora_rank: int = 4,
+    lora_alpha: int = 32,
+    lora_dropout: float = 0,
+    lora_bias: str = "none",
     **kwargs
 ):
-    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    unet_config = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_alpha,
+        target_modules=UNET_TARGET_MODULES,
+        lora_dropout=lora_dropout,
+        bias=lora_bias,
+    )
+    unet = LoraModel(unet_config, unet)
+
+    text_encoder_config = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_alpha,
+        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        lora_dropout=lora_dropout,
+        bias=lora_bias,
+    )
+    text_encoder = LoraModel(text_encoder_config, text_encoder)
+
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 lora_strategy = TrainingStrategy(
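lora_prepare now wraps both models with peft's LoraModel, targeting only the q/v projections (to_q/to_v in the diffusers UNet, q_proj/v_proj in the CLIP text encoder). As an illustration of the chosen defaults: in peft's LoRA implementation the low-rank update is scaled by lora_alpha / r before being added to the frozen weight, so r=4 with lora_alpha=32 applies an effective factor of 8. The helper below is only a hypothetical illustration of that arithmetic, not part of this commit.

```python
# Illustration of peft's LoRA scaling: the low-rank update B @ A is multiplied
# by lora_alpha / r before being added to the frozen weight.
def lora_scaling(lora_rank: int = 4, lora_alpha: int = 32) -> float:
    return lora_alpha / lora_rank

print(lora_scaling())  # 8.0 with the defaults used in lora_prepare above
```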