summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-24 10:53:16 +0100
committerVolpeon <git@volpeon.ink>2023-03-24 10:53:16 +0100
commit95adaea8b55d8e3755c035758bc649ae22548572 (patch)
tree80239f0bc55b99615718a935be2caa2e1e68e20a
parentBring back Perlin offset noise (diff)
downloadtextual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.tar.gz
textual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.tar.bz2
textual-inversion-diff-95adaea8b55d8e3755c035758bc649ae22548572.zip
Refactoring, fixed Lora training
-rw-r--r--train_lora.py73
-rw-r--r--training/functional.py9
-rw-r--r--training/strategy/dreambooth.py17
-rw-r--r--training/strategy/lora.py49
-rw-r--r--training/strategy/ti.py22
5 files changed, 104 insertions, 66 deletions
diff --git a/train_lora.py b/train_lora.py
index 8dd3c86..fa24cee 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -11,6 +11,7 @@ import torch.utils.checkpoint
11from accelerate import Accelerator 11from accelerate import Accelerator
12from accelerate.logging import get_logger 12from accelerate.logging import get_logger
13from accelerate.utils import LoggerType, set_seed 13from accelerate.utils import LoggerType, set_seed
14from peft import LoraConfig, LoraModel
14from slugify import slugify 15from slugify import slugify
15 16
16from util.files import load_config, load_embeddings_from_dir 17from util.files import load_config, load_embeddings_from_dir
@@ -21,6 +22,11 @@ from training.strategy.lora import lora_strategy
21from training.optimization import get_scheduler 22from training.optimization import get_scheduler
22from training.util import save_args 23from training.util import save_args
23 24
25# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
26UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
27TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
28
29
24logger = get_logger(__name__) 30logger = get_logger(__name__)
25 31
26 32
@@ -176,6 +182,54 @@ def parse_args():
176 help="Number of updates steps to accumulate before performing a backward/update pass.", 182 help="Number of updates steps to accumulate before performing a backward/update pass.",
177 ) 183 )
178 parser.add_argument( 184 parser.add_argument(
185 "--lora_r",
186 type=int,
187 default=8,
188 help="Lora rank, only used if use_lora is True"
189 )
190 parser.add_argument(
191 "--lora_alpha",
192 type=int,
193 default=32,
194 help="Lora alpha, only used if use_lora is True"
195 )
196 parser.add_argument(
197 "--lora_dropout",
198 type=float,
199 default=0.0,
200 help="Lora dropout, only used if use_lora is True"
201 )
202 parser.add_argument(
203 "--lora_bias",
204 type=str,
205 default="none",
206 help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
207 )
208 parser.add_argument(
209 "--lora_text_encoder_r",
210 type=int,
211 default=8,
212 help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
213 )
214 parser.add_argument(
215 "--lora_text_encoder_alpha",
216 type=int,
217 default=32,
218 help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
219 )
220 parser.add_argument(
221 "--lora_text_encoder_dropout",
222 type=float,
223 default=0.0,
224 help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
225 )
226 parser.add_argument(
227 "--lora_text_encoder_bias",
228 type=str,
229 default="none",
230 help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
231 )
232 parser.add_argument(
179 "--find_lr", 233 "--find_lr",
180 action="store_true", 234 action="store_true",
181 help="Automatically find a learning rate (no training).", 235 help="Automatically find a learning rate (no training).",
@@ -424,13 +478,30 @@ def main():
424 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 478 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
425 args.pretrained_model_name_or_path) 479 args.pretrained_model_name_or_path)
426 480
481 unet_config = LoraConfig(
482 r=args.lora_r,
483 lora_alpha=args.lora_alpha,
484 target_modules=UNET_TARGET_MODULES,
485 lora_dropout=args.lora_dropout,
486 bias=args.lora_bias,
487 )
488 unet = LoraModel(unet_config, unet)
489
490 text_encoder_config = LoraConfig(
491 r=args.lora_text_encoder_r,
492 lora_alpha=args.lora_text_encoder_alpha,
493 target_modules=TEXT_ENCODER_TARGET_MODULES,
494 lora_dropout=args.lora_text_encoder_dropout,
495 bias=args.lora_text_encoder_bias,
496 )
497 text_encoder = LoraModel(text_encoder_config, text_encoder)
498
427 vae.enable_slicing() 499 vae.enable_slicing()
428 vae.set_use_memory_efficient_attention_xformers(True) 500 vae.set_use_memory_efficient_attention_xformers(True)
429 unet.enable_xformers_memory_efficient_attention() 501 unet.enable_xformers_memory_efficient_attention()
430 502
431 if args.gradient_checkpointing: 503 if args.gradient_checkpointing:
432 unet.enable_gradient_checkpointing() 504 unet.enable_gradient_checkpointing()
433 text_encoder.gradient_checkpointing_enable()
434 505
435 if args.embeddings_dir is not None: 506 if args.embeddings_dir is not None:
436 embeddings_dir = Path(args.embeddings_dir) 507 embeddings_dir = Path(args.embeddings_dir)
diff --git a/training/functional.py b/training/functional.py
index a5b339d..ee73ab2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,6 @@ def const(result=None):
34 34
35@dataclass 35@dataclass
36class TrainingCallbacks(): 36class TrainingCallbacks():
37 on_prepare: Callable[[], None] = const()
38 on_accum_model: Callable[[], torch.nn.Module] = const(None) 37 on_accum_model: Callable[[], torch.nn.Module] = const(None)
39 on_log: Callable[[], dict[str, Any]] = const({}) 38 on_log: Callable[[], dict[str, Any]] = const({})
40 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 39 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -620,10 +619,8 @@ def train(
620 kwargs.update(extra) 619 kwargs.update(extra)
621 620
622 vae.to(accelerator.device, dtype=dtype) 621 vae.to(accelerator.device, dtype=dtype)
623 622 vae.requires_grad_(False)
624 for model in (unet, text_encoder, vae): 623 vae.eval()
625 model.requires_grad_(False)
626 model.eval()
627 624
628 callbacks = strategy.callbacks( 625 callbacks = strategy.callbacks(
629 accelerator=accelerator, 626 accelerator=accelerator,
@@ -636,8 +633,6 @@ def train(
636 **kwargs, 633 **kwargs,
637 ) 634 )
638 635
639 callbacks.on_prepare()
640
641 loss_step_ = partial( 636 loss_step_ = partial(
642 loss_step, 637 loss_step,
643 vae, 638 vae,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 28fccff..9808027 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -74,6 +74,7 @@ def dreambooth_strategy_callbacks(
74 power=ema_power, 74 power=ema_power,
75 max_value=ema_max_decay, 75 max_value=ema_max_decay,
76 ) 76 )
77 ema_unet.to(accelerator.device)
77 else: 78 else:
78 ema_unet = None 79 ema_unet = None
79 80
@@ -86,14 +87,6 @@ def dreambooth_strategy_callbacks(
86 def on_accum_model(): 87 def on_accum_model():
87 return unet 88 return unet
88 89
89 def on_prepare():
90 unet.requires_grad_(True)
91 text_encoder.text_model.encoder.requires_grad_(True)
92 text_encoder.text_model.final_layer_norm.requires_grad_(True)
93
94 if ema_unet is not None:
95 ema_unet.to(accelerator.device)
96
97 @contextmanager 90 @contextmanager
98 def on_train(epoch: int): 91 def on_train(epoch: int):
99 tokenizer.train() 92 tokenizer.train()
@@ -181,7 +174,6 @@ def dreambooth_strategy_callbacks(
181 torch.cuda.empty_cache() 174 torch.cuda.empty_cache()
182 175
183 return TrainingCallbacks( 176 return TrainingCallbacks(
184 on_prepare=on_prepare,
185 on_accum_model=on_accum_model, 177 on_accum_model=on_accum_model,
186 on_train=on_train, 178 on_train=on_train,
187 on_eval=on_eval, 179 on_eval=on_eval,
@@ -203,7 +195,12 @@ def dreambooth_prepare(
203 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 195 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
204 **kwargs 196 **kwargs
205): 197):
206 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) 198 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
199 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
200
201 text_encoder.text_model.embeddings.requires_grad_(False)
202
203 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
207 204
208 205
209dreambooth_strategy = TrainingStrategy( 206dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1c8fad6..3971eae 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -10,18 +10,12 @@ from torch.utils.data import DataLoader
10from accelerate import Accelerator 10from accelerate import Accelerator
11from transformers import CLIPTextModel 11from transformers import CLIPTextModel
12from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler 12from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
13from peft import LoraConfig, LoraModel, get_peft_model_state_dict 13from peft import get_peft_model_state_dict
14from peft.tuners.lora import mark_only_lora_as_trainable
15 14
16from models.clip.tokenizer import MultiCLIPTokenizer 15from models.clip.tokenizer import MultiCLIPTokenizer
17from training.functional import TrainingStrategy, TrainingCallbacks, save_samples 16from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
18 17
19 18
20# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
21UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
22TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
23
24
25def lora_strategy_callbacks( 19def lora_strategy_callbacks(
26 accelerator: Accelerator, 20 accelerator: Accelerator,
27 unet: UNet2DConditionModel, 21 unet: UNet2DConditionModel,
@@ -61,10 +55,6 @@ def lora_strategy_callbacks(
61 image_size=sample_image_size, 55 image_size=sample_image_size,
62 ) 56 )
63 57
64 def on_prepare():
65 mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
66 mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
67
68 def on_accum_model(): 58 def on_accum_model():
69 return unet 59 return unet
70 60
@@ -93,15 +83,15 @@ def lora_strategy_callbacks(
93 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) 83 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
94 84
95 lora_config = {} 85 lora_config = {}
96 state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) 86 state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
97 lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True) 87 lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
98 88
99 text_encoder_state_dict = get_peft_model_state_dict( 89 text_encoder_state_dict = get_peft_model_state_dict(
100 text_encoder, state_dict=accelerator.get_state_dict(text_encoder) 90 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
101 ) 91 )
102 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} 92 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
103 state_dict.update(text_encoder_state_dict) 93 state_dict.update(text_encoder_state_dict)
104 lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True) 94 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
105 95
106 accelerator.print(state_dict) 96 accelerator.print(state_dict)
107 accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt") 97 accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
@@ -111,11 +101,16 @@ def lora_strategy_callbacks(
111 101
112 @torch.no_grad() 102 @torch.no_grad()
113 def on_sample(step): 103 def on_sample(step):
104 vae_dtype = vae.dtype
105 vae.to(dtype=text_encoder.dtype)
106
114 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) 107 unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
115 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) 108 text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
116 109
117 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) 110 save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
118 111
112 vae.to(dtype=vae_dtype)
113
119 del unet_ 114 del unet_
120 del text_encoder_ 115 del text_encoder_
121 116
@@ -123,7 +118,6 @@ def lora_strategy_callbacks(
123 torch.cuda.empty_cache() 118 torch.cuda.empty_cache()
124 119
125 return TrainingCallbacks( 120 return TrainingCallbacks(
126 on_prepare=on_prepare,
127 on_accum_model=on_accum_model, 121 on_accum_model=on_accum_model,
128 on_train=on_train, 122 on_train=on_train,
129 on_eval=on_eval, 123 on_eval=on_eval,
@@ -147,28 +141,7 @@ def lora_prepare(
147 lora_bias: str = "none", 141 lora_bias: str = "none",
148 **kwargs 142 **kwargs
149): 143):
150 unet_config = LoraConfig( 144 return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
151 r=lora_rank,
152 lora_alpha=lora_alpha,
153 target_modules=UNET_TARGET_MODULES,
154 lora_dropout=lora_dropout,
155 bias=lora_bias,
156 )
157 unet = LoraModel(unet_config, unet)
158
159 text_encoder_config = LoraConfig(
160 r=lora_rank,
161 lora_alpha=lora_alpha,
162 target_modules=TEXT_ENCODER_TARGET_MODULES,
163 lora_dropout=lora_dropout,
164 bias=lora_bias,
165 )
166 text_encoder = LoraModel(text_encoder_config, text_encoder)
167
168 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
169 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
170
171 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
172 145
173 146
174lora_strategy = TrainingStrategy( 147lora_strategy = TrainingStrategy(
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 2038e34..10bc6d7 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -78,6 +78,7 @@ def textual_inversion_strategy_callbacks(
78 power=ema_power, 78 power=ema_power,
79 max_value=ema_max_decay, 79 max_value=ema_max_decay,
80 ) 80 )
81 ema_embeddings.to(accelerator.device)
81 else: 82 else:
82 ema_embeddings = None 83 ema_embeddings = None
83 84
@@ -92,15 +93,6 @@ def textual_inversion_strategy_callbacks(
92 def on_accum_model(): 93 def on_accum_model():
93 return text_encoder.text_model.embeddings.temp_token_embedding 94 return text_encoder.text_model.embeddings.temp_token_embedding
94 95
95 def on_prepare():
96 text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
97
98 if ema_embeddings is not None:
99 ema_embeddings.to(accelerator.device)
100
101 if gradient_checkpointing:
102 unet.train()
103
104 @contextmanager 96 @contextmanager
105 def on_train(epoch: int): 97 def on_train(epoch: int):
106 tokenizer.train() 98 tokenizer.train()
@@ -177,7 +169,6 @@ def textual_inversion_strategy_callbacks(
177 torch.cuda.empty_cache() 169 torch.cuda.empty_cache()
178 170
179 return TrainingCallbacks( 171 return TrainingCallbacks(
180 on_prepare=on_prepare,
181 on_accum_model=on_accum_model, 172 on_accum_model=on_accum_model,
182 on_train=on_train, 173 on_train=on_train,
183 on_eval=on_eval, 174 on_eval=on_eval,
@@ -197,6 +188,7 @@ def textual_inversion_prepare(
197 train_dataloader: DataLoader, 188 train_dataloader: DataLoader,
198 val_dataloader: Optional[DataLoader], 189 val_dataloader: Optional[DataLoader],
199 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 190 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
191 gradient_checkpointing: bool = False,
200 **kwargs 192 **kwargs
201): 193):
202 weight_dtype = torch.float32 194 weight_dtype = torch.float32
@@ -207,7 +199,17 @@ def textual_inversion_prepare(
207 199
208 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 200 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
209 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) 201 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
202
210 unet.to(accelerator.device, dtype=weight_dtype) 203 unet.to(accelerator.device, dtype=weight_dtype)
204 unet.requires_grad_(False)
205 unet.eval()
206 if gradient_checkpointing:
207 unet.train()
208
209 text_encoder.text_model.encoder.requires_grad_(False)
210 text_encoder.text_model.final_layer_norm.requires_grad_(False)
211 text_encoder.eval()
212
211 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} 213 return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
212 214
213 215