diff options
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/lora.py | 70 |
1 file changed, 59 insertions, 11 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cab5e4c..aa75bec 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -2,6 +2,7 @@ from typing import Optional | |||
2 | from functools import partial | 2 | from functools import partial |
3 | from contextlib import contextmanager | 3 | from contextlib import contextmanager |
4 | from pathlib import Path | 4 | from pathlib import Path |
5 | import itertools | ||
5 | 6 | ||
6 | import torch | 7 | import torch |
7 | from torch.utils.data import DataLoader | 8 | from torch.utils.data import DataLoader |
@@ -9,12 +10,18 @@ from torch.utils.data import DataLoader | |||
9 | from accelerate import Accelerator | 10 | from accelerate import Accelerator |
10 | from transformers import CLIPTextModel | 11 | from transformers import CLIPTextModel |
11 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler | 12 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler |
12 | from diffusers.loaders import AttnProcsLayers | 13 | from peft import LoraConfig, LoraModel, get_peft_model_state_dict |
14 | from peft.tuners.lora import mark_only_lora_as_trainable | ||
13 | 15 | ||
14 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
15 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | 17 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
16 | 18 | ||
17 | 19 | ||
20 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | ||
21 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] | ||
22 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] | ||
23 | |||
24 | |||
18 | def lora_strategy_callbacks( | 25 | def lora_strategy_callbacks( |
19 | accelerator: Accelerator, | 26 | accelerator: Accelerator, |
20 | unet: UNet2DConditionModel, | 27 | unet: UNet2DConditionModel, |
@@ -27,7 +34,6 @@ def lora_strategy_callbacks( | |||
27 | sample_output_dir: Path, | 34 | sample_output_dir: Path, |
28 | checkpoint_output_dir: Path, | 35 | checkpoint_output_dir: Path, |
29 | seed: int, | 36 | seed: int, |
30 | lora_layers: AttnProcsLayers, | ||
31 | max_grad_norm: float = 1.0, | 37 | max_grad_norm: float = 1.0, |
32 | sample_batch_size: int = 1, | 38 | sample_batch_size: int = 1, |
33 | sample_num_batches: int = 1, | 39 | sample_num_batches: int = 1, |
@@ -57,7 +63,8 @@ def lora_strategy_callbacks( | |||
57 | ) | 63 | ) |
58 | 64 | ||
59 | def on_prepare(): | 65 | def on_prepare(): |
60 | lora_layers.requires_grad_(True) | 66 | mark_only_lora_as_trainable(unet.model, unet.peft_config.bias) |
67 | mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias) | ||
61 | 68 | ||
62 | def on_accum_model(): | 69 | def on_accum_model(): |
63 | return unet | 70 | return unet |
@@ -73,24 +80,44 @@ def lora_strategy_callbacks( | |||
73 | yield | 80 | yield |
74 | 81 | ||
75 | def on_before_optimize(lr: float, epoch: int): | 82 | def on_before_optimize(lr: float, epoch: int): |
76 | accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm) | 83 | accelerator.clip_grad_norm_( |
84 | itertools.chain(unet.parameters(), text_encoder.parameters()), | ||
85 | max_grad_norm | ||
86 | ) | ||
77 | 87 | ||
78 | @torch.no_grad() | 88 | @torch.no_grad() |
79 | def on_checkpoint(step, postfix): | 89 | def on_checkpoint(step, postfix): |
80 | print(f"Saving checkpoint for step {step}...") | 90 | print(f"Saving checkpoint for step {step}...") |
81 | 91 | ||
82 | unet_ = accelerator.unwrap_model(unet, False) | 92 | unet_ = accelerator.unwrap_model(unet, False) |
83 | unet_.save_attn_procs( | 93 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) |
84 | checkpoint_output_dir / f"{step}_{postfix}", | 94 | |
85 | safe_serialization=True | 95 | lora_config = {} |
96 | state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) | ||
97 | lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True) | ||
98 | |||
99 | text_encoder_state_dict = get_peft_model_state_dict( | ||
100 | text_encoder, state_dict=accelerator.get_state_dict(text_encoder) | ||
86 | ) | 101 | ) |
102 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} | ||
103 | state_dict.update(text_encoder_state_dict) | ||
104 | lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) | ||
105 | |||
106 | accelerator.print(state_dict) | ||
107 | accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") | ||
108 | |||
87 | del unet_ | 109 | del unet_ |
110 | del text_encoder_ | ||
88 | 111 | ||
89 | @torch.no_grad() | 112 | @torch.no_grad() |
90 | def on_sample(step): | 113 | def on_sample(step): |
91 | unet_ = accelerator.unwrap_model(unet, False) | 114 | unet_ = accelerator.unwrap_model(unet, False) |
115 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | ||
116 | |||
92 | save_samples_(step=step, unet=unet_) | 117 | save_samples_(step=step, unet=unet_) |
118 | |||
93 | del unet_ | 119 | del unet_ |
120 | del text_encoder_ | ||
94 | 121 | ||
95 | if torch.cuda.is_available(): | 122 | if torch.cuda.is_available(): |
96 | torch.cuda.empty_cache() | 123 | torch.cuda.empty_cache() |
@@ -114,13 +141,34 @@ def lora_prepare( | |||
114 | train_dataloader: DataLoader, | 141 | train_dataloader: DataLoader, |
115 | val_dataloader: Optional[DataLoader], | 142 | val_dataloader: Optional[DataLoader], |
116 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 143 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
117 | lora_layers: AttnProcsLayers, | 144 | lora_rank: int = 4, |
145 | lora_alpha: int = 32, | ||
146 | lora_dropout: float = 0, | ||
147 | lora_bias: str = "none", | ||
118 | **kwargs | 148 | **kwargs |
119 | ): | 149 | ): |
120 | lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 150 | unet_config = LoraConfig( |
121 | lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 151 | r=lora_rank, |
152 | lora_alpha=lora_alpha, | ||
153 | target_modules=UNET_TARGET_MODULES, | ||
154 | lora_dropout=lora_dropout, | ||
155 | bias=lora_bias, | ||
156 | ) | ||
157 | unet = LoraModel(unet_config, unet) | ||
158 | |||
159 | text_encoder_config = LoraConfig( | ||
160 | r=lora_rank, | ||
161 | lora_alpha=lora_alpha, | ||
162 | target_modules=TEXT_ENCODER_TARGET_MODULES, | ||
163 | lora_dropout=lora_dropout, | ||
164 | bias=lora_bias, | ||
165 | ) | ||
166 | text_encoder = LoraModel(text_encoder_config, text_encoder) | ||
167 | |||
168 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | ||
169 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | ||
122 | 170 | ||
123 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers} | 171 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} |
124 | 172 | ||
125 | 173 | ||
126 | lora_strategy = TrainingStrategy( | 174 | lora_strategy = TrainingStrategy( |