| author | Volpeon <git@volpeon.ink> | 2023-03-24 10:53:16 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-24 10:53:16 +0100 |
| commit | 95adaea8b55d8e3755c035758bc649ae22548572 | |
| tree | 80239f0bc55b99615718a935be2caa2e1e68e20a /training/strategy/lora.py | |
| parent | Bring back Perlin offset noise | |
Refactoring, fixed Lora training
Diffstat (limited to 'training/strategy/lora.py')

| -rw-r--r-- | training/strategy/lora.py | 49 |
|---|---|---|

1 file changed, 11 insertions, 38 deletions
```diff
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1c8fad6..3971eae 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -10,18 +10,12 @@ from torch.utils.data import DataLoader
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
-from peft import LoraConfig, LoraModel, get_peft_model_state_dict
-from peft.tuners.lora import mark_only_lora_as_trainable
+from peft import get_peft_model_state_dict
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
-UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
-
-
 def lora_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
@@ -61,10 +55,6 @@ def lora_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    def on_prepare():
-        mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
-        mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
-
     def on_accum_model():
         return unet
 
@@ -93,15 +83,15 @@
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
         lora_config = {}
-        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+        state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+        lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
         text_encoder_state_dict = get_peft_model_state_dict(
-            text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
         )
         text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
         state_dict.update(text_encoder_state_dict)
-        lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+        lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         accelerator.print(state_dict)
         accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
@@ -111,11 +101,16 @@
 
     @torch.no_grad()
     def on_sample(step):
+        vae_dtype = vae.dtype
+        vae.to(dtype=text_encoder.dtype)
+
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
         save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
+        vae.to(dtype=vae_dtype)
+
         del unet_
         del text_encoder_
 
@@ -123,7 +118,6 @@
             torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -147,28 +141,7 @@ def lora_prepare(
     lora_bias: str = "none",
     **kwargs
 ):
-    unet_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_alpha,
-        target_modules=UNET_TARGET_MODULES,
-        lora_dropout=lora_dropout,
-        bias=lora_bias,
-    )
-    unet = LoraModel(unet_config, unet)
-
-    text_encoder_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
-        lora_dropout=lora_dropout,
-        bias=lora_bias,
-    )
-    text_encoder = LoraModel(text_encoder_config, text_encoder)
-
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
 
 
 lora_strategy = TrainingStrategy(
```
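After this change, `lora_prepare` no longer builds the PEFT wrappers itself, so the strategy assumes `unet` and `text_encoder` already carry their LoRA adapters when they are passed in (the checkpoint callback still calls `get_peft_model_state_dict` and `get_peft_config_as_dict` on them). Below is a minimal sketch of what that up-front wrapping could look like, reusing the `LoraConfig`/`LoraModel` calls and target-module lists removed from this file; the actual call site elsewhere in the repository is not part of this diff, and the helper name is made up for illustration.

```python
# Sketch only: wrap the models with PEFT before handing them to lora_prepare.
# Reuses the same LoraConfig/LoraModel API and target-module lists that the
# removed code used; where the repository now performs this step is not shown here.
from peft import LoraConfig, LoraModel

UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]


def wrap_with_lora(unet, text_encoder, lora_rank=4, lora_alpha=32,
                   lora_dropout=0.0, lora_bias="none"):
    # Wrapping in LoraModel also freezes everything except the LoRA weights,
    # which is presumably why the explicit on_prepare /
    # mark_only_lora_as_trainable step could be dropped from the callbacks.
    unet = LoraModel(
        LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=UNET_TARGET_MODULES,
            lora_dropout=lora_dropout,
            bias=lora_bias,
        ),
        unet,
    )
    text_encoder = LoraModel(
        LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=TEXT_ENCODER_TARGET_MODULES,
            lora_dropout=lora_dropout,
            bias=lora_bias,
        ),
        text_encoder,
    )
    return unet, text_encoder
```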
