summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-24 10:53:16 +0100
committerVolpeon <git@volpeon.ink>2023-03-24 10:53:16 +0100
commit95adaea8b55d8e3755c035758bc649ae22548572 (patch)
tree80239f0bc55b99615718a935be2caa2e1e68e20a /training
parentBring back Perlin offset noise (diff)
downloadtextual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.tar.gz
textual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.tar.bz2
textual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.zip
Refactoring, fixed Lora training
Diffstat (limited to 'training')
-rw-r--r--training/functional.py9
-rw-r--r--training/strategy/dreambooth.py17
-rw-r--r--training/strategy/lora.py49
-rw-r--r--training/strategy/ti.py22
4 files changed, 32 insertions, 65 deletions
diff --git a/training/functional.py b/training/functional.py
index a5b339d..ee73ab2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,6 @@ def const(result=None):
34 34
35@dataclass 35@dataclass
36class TrainingCallbacks(): 36class TrainingCallbacks():
37 on_prepare: Callable[[], None] = const()
38 on_accum_model: Callable[[], torch.nn.Module] = const(None) 37 on_accum_model: Callable[[], torch.nn.Module] = const(None)
39 on_log: Callable[[], dict[str, Any]] = const({}) 38 on_log: Callable[[], dict[str, Any]] = const({})
40 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 39 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -620,10 +619,8 @@ def train(
620 kwargs.update(extra) 619 kwargs.update(extra)
621 620
622 vae.to(accelerator.device, dtype=dtype) 621 vae.to(accelerator.device, dtype=dtype)
623 622 vae.requires_grad_(False)
624 for model in (unet, text_encoder, vae): 623 vae.eval()
625 model.requires_grad_(False)
626 model.eval()
627 624
628 callbacks = strategy.callbacks( 625 callbacks = strategy.callbacks(
629 accelerator=accelerator, 626 accelerator=accelerator,
@@ -636,8 +633,6 @@ def train(
636 **kwargs, 633 **kwargs,
637 ) 634 )
638 635
639 callbacks.on_prepare()
640
641 loss_step_ = partial( 636 loss_step_ = partial(
642 loss_step, 637 loss_step,
643 vae, 638 vae,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 28fccff..9808027 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -74,6 +74,7 @@ def dreambooth_strategy_callbacks(
74 power=ema_power, 74 power=ema_power,
75 max_value=ema_max_decay, 75 max_value=ema_max_decay,
76 ) 76 )
77 ema_unet.to(accelerator.device)
77 else: 78 else:
78 ema_unet = None 79 ema_unet = None
79 80
@@ -86,14 +87,6 @@ def dreambooth_strategy_callbacks(
86 def on_accum_model(): 87 def on_accum_model():
87 return unet 88 return unet
88 89
89 def on_prepare():
90 unet.requires_grad_(True)
91 text_encoder.text_model.encoder.requires_grad_(True)
92 text_encoder.text_model.final_layer_norm.requires_grad_(True)
93
94 if ema_unet is not None:
95 ema_unet.to(accelerator.device)
96
97 @contextmanager 90 @contextmanager
98 def on_train(epoch: int): 91 def on_train(epoch: int):
99 tokenizer.train() 92 tokenizer.train()
@@ -181,7 +174,6 @@ def dreambooth_strategy_callbacks(
181 torch.cuda.empty_cache() 174 torch.cuda.empty_cache()
182 175
183 return TrainingCallbacks( 176 return TrainingCallbacks(
184 on_prepare=on_prepare,
185 on_accum_model=on_accum_model, 177 on_accum_model=on_accum_model,
186 on_train=on_train, 178 on_train=on_train,
187 on_eval=on_eval, 179 on_eval=on_eval,
@@ -203,7 +195,12 @@ def dreambooth_prepare(
203 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 195 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
204 **kwargs 196 **kwargs
205): 197):
206 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) 198 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
199 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
200
201 text_encoder.text_model.embeddings.requires_grad_(False)
202
203 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
207 204
208 205
209dreambooth_strategy = TrainingStrategy( 206dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1c8fad6..3971eae 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -10,18 +10,12 @@ from torch.utils.data import DataLoader
10from accelerate import Accelerator 10from accelerate import Accelerator
11from transformers import CLIPTextModel 11from transformers import CLIPTextModel
12from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler 12from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
13from peft import LoraConfig, LoraModel, get_peft_model_state_dict 13from peft import get_peft_model_state_dict
14from peft.tuners.lora import mark_only_lora_as_trainable
15 14
16from models.clip.tokenizer import MultiCLIPTokenizer 15from models.clip.tokenizer import MultiCLIPTokenizer
17from training.functional import TrainingStrategy, TrainingCallbacks, save_samples 16from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
18 17
19 18
20# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
21UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
22TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
23
24
25def lora_strategy_callbacks( 19def lora_strategy_callbacks(
26 accelerator: Accelerator, 20 accelerator: Accelerator,
27 unet: UNet2DConditionModel, 21 unet: UNet2DConditionModel,
@@ -61,10 +55,6 @@ def lora_strategy_callbacks(
61 image_size=sample_image_size, 55 image_size=sample_image_size,
62 ) 56 )
63 57
64 def on_prepare():
65 mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
66 mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
67
68 def on_accum_model(): 58 def on_accum_model():
69 return unet 59 return unet
70 60
@@ -93,15 +83,15 @@ def lora_strategy_callbacks(
93 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 83 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
94 84
95 lora_config = {} 85 lora_config = {}
96 state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) 86 state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
97 lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True) 87 lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
98 88
99 text_encoder_state_dict = get_peft_model_state_dict( 89 text_encoder_state_dict = get_peft_model_state_dict(
100 text_encoder, state_dict=accelerator.get_state_dict(text_encoder) 90 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
101 ) 91 )
102 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} 92 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
103 state_dict.update(text_encoder_state_dict) 93 state_dict.update(text_encoder_state_dict)
104 lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) 94 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
105 95
106 accelerator.print(state_dict) 96 accelerator.print(state_dict)
107 accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") 97 accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
@@ -111,11 +101,16 @@ def lora_strategy_callbacks(
111 101
112 @torch.no_grad() 102 @torch.no_grad()
113 def on_sample(step): 103 def on_sample(step):
104 vae_dtype = vae.dtype
105 vae.to(dtype=text_encoder.dtype)
106
114 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 107 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
115 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 108 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
116 109
117 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) 110 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
118 111
112 vae.to(dtype=vae_dtype)
113
119 del unet_ 114 del unet_
120 del text_encoder_ 115 del text_encoder_
121 116
@@ -123,7 +118,6 @@ def lora_strategy_callbacks(
123 torch.cuda.empty_cache() 118 torch.cuda.empty_cache()
124 119
125 return TrainingCallbacks( 120 return TrainingCallbacks(
126 on_prepare=on_prepare,
127 on_accum_model=on_accum_model, 121 on_accum_model=on_accum_model,
128 on_train=on_train, 122 on_train=on_train,
129 on_eval=on_eval, 123 on_eval=on_eval,
@@ -147,28 +141,7 @@ def lora_prepare(
147 lora_bias: str = "none", 141 lora_bias: str = "none",
148 **kwargs 142 **kwargs
149): 143):
150 unet_config = LoraConfig( 144 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
151 r=lora_rank,
152 lora_alpha=lora_alpha,
153 target_modules=UNET_TARGET_MODULES,
154 lora_dropout=lora_dropout,
155 bias=lora_bias,
156 )
157 unet = LoraModel(unet_config, unet)
158
159 text_encoder_config = LoraConfig(
160 r=lora_rank,
161 lora_alpha=lora_alpha,
162 target_modules=TEXT_ENCODER_TARGET_MODULES,
163 lora_dropout=lora_dropout,
164 bias=lora_bias,
165 )
166 text_encoder = LoraModel(text_encoder_config, text_encoder)
167
168 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
169 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
170
171 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
172 145
173 146
174lora_strategy = TrainingStrategy( 147lora_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 2038e34..10bc6d7 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -78,6 +78,7 @@ def textual_inversion_strategy_callbacks(
78 power=ema_power, 78 power=ema_power,
79 max_value=ema_max_decay, 79 max_value=ema_max_decay,
80 ) 80 )
81 ema_embeddings.to(accelerator.device)
81 else: 82 else:
82 ema_embeddings = None 83 ema_embeddings = None
83 84
@@ -92,15 +93,6 @@ def textual_inversion_strategy_callbacks(
92 def on_accum_model(): 93 def on_accum_model():
93 return text_encoder.text_model.embeddings.temp_token_embedding 94 return text_encoder.text_model.embeddings.temp_token_embedding
94 95
95 def on_prepare():
96 text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
97
98 if ema_embeddings is not None:
99 ema_embeddings.to(accelerator.device)
100
101 if gradient_checkpointing:
102 unet.train()
103
104 @contextmanager 96 @contextmanager
105 def on_train(epoch: int): 97 def on_train(epoch: int):
106 tokenizer.train() 98 tokenizer.train()
@@ -177,7 +169,6 @@ def textual_inversion_strategy_callbacks(
177 torch.cuda.empty_cache() 169 torch.cuda.empty_cache()
178 170
179 return TrainingCallbacks( 171 return TrainingCallbacks(
180 on_prepare=on_prepare,
181 on_accum_model=on_accum_model, 172 on_accum_model=on_accum_model,
182 on_train=on_train, 173 on_train=on_train,
183 on_eval=on_eval, 174 on_eval=on_eval,
@@ -197,6 +188,7 @@ def textual_inversion_prepare(
197 train_dataloader: DataLoader, 188 train_dataloader: DataLoader,
198 val_dataloader: Optional[DataLoader], 189 val_dataloader: Optional[DataLoader],
199 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 190 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
191 gradient_checkpointing: bool = False,
200 **kwargs 192 **kwargs
201): 193):
202 weight_dtype = torch.float32 194 weight_dtype = torch.float32
@@ -207,7 +199,17 @@ def textual_inversion_prepare(
207 199
208 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 200 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
209 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) 201 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
202
210 unet.to(accelerator.device, dtype=weight_dtype) 203 unet.to(accelerator.device, dtype=weight_dtype)
204 unet.requires_grad_(False)
205 unet.eval()
206 if gradient_checkpointing:
207 unet.train()
208
209 text_encoder.text_model.encoder.requires_grad_(False)
210 text_encoder.text_model.final_layer_norm.requires_grad_(False)
211 text_encoder.eval()
212
211 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} 213 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
212 214
213 215