 train_lora.py             | 37 +-
 training/strategy/lora.py | 35 +-
 2 files changed, 39 insertions(+), 33 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 476efcf..5b0a292 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -21,7 +21,6 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.lr import plot_metrics
 from training.strategy.lora import lora_strategy
-from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
 
@@ -829,6 +828,12 @@ def main():
         sample_num_batches=args.sample_batches,
         sample_num_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
+        placeholder_tokens=args.placeholder_tokens,
+        placeholder_token_ids=placeholder_token_ids,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay=args.emb_decay,
+        max_grad_norm=args.max_grad_norm,
     )
 
     create_datamodule = partial(
@@ -907,7 +912,8 @@ def main():
     )
 
     metrics = trainer(
-        strategy=textual_inversion_strategy,
+        strategy=lora_strategy,
+        pti_mode=True,
         project="pti",
         train_dataloader=pti_datamodule.train_dataloader,
         val_dataloader=pti_datamodule.val_dataloader,
@@ -919,11 +925,6 @@ def main():
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=math.inf,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
     )
 
     plot_metrics(metrics, pti_output_dir / "lr.png")
@@ -952,13 +953,21 @@ def main():
     lora_optimizer = create_optimizer(
         [
             {
-                "params": unet.parameters(),
+                "params": (
+                    param
+                    for param in unet.parameters()
+                    if param.requires_grad
+                ),
                 "lr": args.learning_rate_unet,
             },
             {
-                "params": itertools.chain(
-                    text_encoder.text_model.encoder.parameters(),
-                    text_encoder.text_model.final_layer_norm.parameters(),
+                "params": (
+                    param
+                    for param in itertools.chain(
+                        text_encoder.text_model.encoder.parameters(),
+                        text_encoder.text_model.final_layer_norm.parameters(),
+                    )
+                    if param.requires_grad
                 ),
                 "lr": args.learning_rate_text,
             },
@@ -990,12 +999,6 @@ def main():
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-        max_grad_norm=args.max_grad_norm,
     )
 
     plot_metrics(metrics, lora_output_dir / "lr.png")
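
Note on the optimizer hunk above: with PEFT-style LoRA adapters, most base unet and text-encoder weights are frozen (requires_grad is False), so the generator expressions hand the optimizer only the parameters that will actually receive gradients. A minimal standalone sketch of that filtering follows; the trainable_params helper is illustrative and not part of this repository.

import itertools


def trainable_params(*modules):
    # Yield only parameters that will actually receive gradients,
    # e.g. LoRA adapter weights, while frozen base weights are skipped.
    for param in itertools.chain(*(m.parameters() for m in modules)):
        if param.requires_grad:
            yield param


# Hypothetical usage mirroring the optimizer groups in the hunk above:
# create_optimizer([
#     {"params": trainable_params(unet), "lr": args.learning_rate_unet},
#     {"params": trainable_params(text_encoder.text_model.encoder,
#                                 text_encoder.text_model.final_layer_norm),
#      "lr": args.learning_rate_text},
# ])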
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 80ffa9c..912ff26 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -34,6 +34,7 @@ def lora_strategy_callbacks(
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
+    pti_mode: bool = False,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -79,10 +80,11 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        accelerator.clip_grad_norm_(
-            itertools.chain(unet.parameters(), text_encoder.parameters()),
-            max_grad_norm
-        )
+        if not pti_mode:
+            accelerator.clip_grad_norm_(
+                itertools.chain(unet.parameters(), text_encoder.parameters()),
+                max_grad_norm
+            )
 
         if use_emb_decay:
             params = [
@@ -117,20 +119,21 @@ def lora_strategy_callbacks(
                 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
             )
 
-        lora_config = {}
-        state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
-        lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
+        if not pti_mode:
+            lora_config = {}
+            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
-        text_encoder_state_dict = get_peft_model_state_dict(
-            text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
-        )
-        text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
-        state_dict.update(text_encoder_state_dict)
-        lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            text_encoder_state_dict = get_peft_model_state_dict(
+                text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
+            )
+            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            state_dict.update(text_encoder_state_dict)
+            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
-        with open(checkpoint_output_dir / "lora_config.json", "w") as f:
-            json.dump(lora_config, f)
+            save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+            with open(checkpoint_output_dir / "lora_config.json", "w") as f:
+                json.dump(lora_config, f)
 
         del unet_
         del text_encoder_
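
With the new pti_mode flag, the same lora_strategy now covers both training phases: during the PTI pass only the placeholder-token embeddings are trained, so full-model gradient clipping and the LoRA safetensors/lora_config.json checkpoint are skipped, while embedding decay and the per-token .bin embeddings still apply. A small self-contained sketch of that gating follows, with torch.nn.utils.clip_grad_norm_ standing in for accelerator.clip_grad_norm_ and all other names illustrative rather than taken from this repository.

import itertools

import torch


def make_on_before_optimize(unet, text_encoder, clip_fn, pti_mode=False, max_grad_norm=1.0):
    def on_before_optimize(lr, epoch):
        if not pti_mode:
            # Clipping the full unet/text-encoder gradients only matters in the
            # LoRA phase, where the adapter weights carry gradients.
            clip_fn(
                itertools.chain(unet.parameters(), text_encoder.parameters()),
                max_grad_norm,
            )
        # Embedding decay (when enabled) would still run here in both modes.

    return on_before_optimize


# During PTI the clipping call is skipped entirely:
unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)
on_before_optimize = make_on_before_optimize(
    unet, text_encoder, torch.nn.utils.clip_grad_norm_, pti_mode=True
)
on_before_optimize(lr=1e-4, epoch=0)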