author    Volpeon <git@volpeon.ink>  2023-03-24 10:53:16 +0100
committer Volpeon <git@volpeon.ink>  2023-03-24 10:53:16 +0100
commit    95adaea8b55d8e3755c035758bc649ae22548572 (patch)
tree      80239f0bc55b99615718a935be2caa2e1e68e20a /training/strategy/lora.py
parent    Bring back Perlin offset noise (diff)
Refactoring, fixed Lora training
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--  training/strategy/lora.py  49
1 file changed, 11 insertions(+), 38 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1c8fad6..3971eae 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -10,18 +10,12 @@ from torch.utils.data import DataLoader
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
-from peft import LoraConfig, LoraModel, get_peft_model_state_dict
-from peft.tuners.lora import mark_only_lora_as_trainable
+from peft import get_peft_model_state_dict
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
-UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
-
-
 def lora_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
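
The removed UNET_TARGET_MODULES / TEXT_ENCODER_TARGET_MODULES constants name the attention projections that receive LoRA adapters: to_q/to_v (and the legacy query/value naming) in the diffusers UNet, q_proj/v_proj in CLIP's attention blocks. peft adapts modules whose qualified name ends with one of these strings; a minimal sketch of that matching follows (find_targets is an illustrative helper, not part of this repo):

import torch.nn as nn

def find_targets(model: nn.Module, targets: list[str]) -> list[str]:
    # peft-style matching: a linear layer is adapted if its qualified
    # name ends with one of the target strings.
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(name.endswith(t) for t in targets)
    ]

# e.g. find_targets(unet, ["to_q", "to_v"]) lists the attention
# projections that would get LoRA weights.
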
@@ -61,10 +55,6 @@ def lora_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    def on_prepare():
-        mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
-        mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
-
     def on_accum_model():
         return unet
 
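
The on_prepare hook goes away together with the import above: mark_only_lora_as_trainable freezes every parameter that is not a LoRA weight, and with the models now arriving already LoRA-wrapped, that freezing is no longer this strategy's job. Roughly what the removed call does, as a simplified sketch of peft's helper (the real implementation also handles bias="lora_only"):

import torch.nn as nn

def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    # Freeze everything that is not a LoRA parameter.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name
    if bias == "all":
        # Optionally keep all bias terms trainable as well.
        for name, param in model.named_parameters():
            if name.endswith("bias"):
                param.requires_grad = True
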
@@ -93,15 +83,15 @@ def lora_strategy_callbacks(
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
         lora_config = {}
-        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+        state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+        lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
         text_encoder_state_dict = get_peft_model_state_dict(
-            text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
         )
         text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
         state_dict.update(text_encoder_state_dict)
-        lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+        lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         accelerator.print(state_dict)
         accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
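
This hunk is the "fixed Lora training" part of the commit: the checkpoint path called the peft helpers on the still-wrapped unet/text_encoder instead of the unwrapped unet_/text_encoder_ obtained just above. Under accelerate, the prepared objects can be e.g. DistributedDataParallel wrappers, whose parameter names carry a "module." prefix and which do not expose get_peft_config_as_dict, so the saved state dict came out with wrong keys or the call failed outright. The corrected pattern, restated as a self-contained sketch (lora_state_dict is an illustrative name):

from accelerate import Accelerator
from peft import get_peft_model_state_dict

def lora_state_dict(accelerator: Accelerator, wrapped_model):
    # Unwrap first: the prepared model may be an accelerate/DDP wrapper
    # whose parameter names carry a "module." prefix and which lacks
    # peft's API.
    model = accelerator.unwrap_model(wrapped_model, keep_fp32_wrapper=False)
    # Extract only the LoRA weights, from fully gathered state.
    return get_peft_model_state_dict(model, state_dict=accelerator.get_state_dict(model))
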
@@ -111,11 +101,16 @@ def lora_strategy_callbacks(
 
     @torch.no_grad()
     def on_sample(step):
+        vae_dtype = vae.dtype
+        vae.to(dtype=text_encoder.dtype)
+
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
         save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
+        vae.to(dtype=vae_dtype)
+
         del unet_
         del text_encoder_
 
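
The new lines in on_sample deal with mixed precision: the VAE may be kept at a different dtype (typically float32) than the text encoder and UNet, which trips a dtype mismatch when the sampling pipeline decodes latents. Casting the VAE to the text encoder's dtype for the duration of sampling and restoring it afterwards leaves training untouched. The same save/restore idea as a reusable context manager (cast_dtype is an illustrative helper, not part of this repo):

from contextlib import contextmanager

import torch

@contextmanager
def cast_dtype(module: torch.nn.Module, dtype: torch.dtype):
    original = next(module.parameters()).dtype
    module.to(dtype=dtype)  # cast for the duration of the block
    try:
        yield module
    finally:
        module.to(dtype=original)  # always restore, even if sampling raises

# hypothetical usage:
#     with cast_dtype(vae, text_encoder.dtype):
#         save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
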
@@ -123,7 +118,6 @@ def lora_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_prepare=on_prepare,
         on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
@@ -147,28 +141,7 @@ def lora_prepare(
     lora_bias: str = "none",
     **kwargs
 ):
-    unet_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_alpha,
-        target_modules=UNET_TARGET_MODULES,
-        lora_dropout=lora_dropout,
-        bias=lora_bias,
-    )
-    unet = LoraModel(unet_config, unet)
-
-    text_encoder_config = LoraConfig(
-        r=lora_rank,
-        lora_alpha=lora_alpha,
-        target_modules=TEXT_ENCODER_TARGET_MODULES,
-        lora_dropout=lora_dropout,
-        bias=lora_bias,
-    )
-    text_encoder = LoraModel(text_encoder_config, text_encoder)
-
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
 
 
 lora_strategy = TrainingStrategy(
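
lora_prepare collapses to one line because the LoraConfig/LoraModel wrapping no longer happens here; presumably the models arrive already wrapped (this diff does not show where). The one-liner relies on accelerator.prepare returning the prepared objects as a tuple in input order when passed several of them, so + ({},) appends the empty extra-state dict that callers expect as the seventh element. Spelled out, under the same argument names as lora_prepare:

prepared = accelerator.prepare(
    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
)
# prepared is a 6-tuple; concatenating ({},) yields the 7-element result
# the caller unpacks:
text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = (
    prepared + ({},)
)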