diff options
| -rw-r--r-- | train_lora.py | 37 | ||||
| -rw-r--r-- | training/strategy/lora.py | 35 |
2 files changed, 39 insertions, 33 deletions
diff --git a/train_lora.py b/train_lora.py index 476efcf..5b0a292 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -21,7 +21,6 @@ from data.csv import VlpnDataModule, keyword_filter | |||
| 21 | from training.functional import train, add_placeholder_tokens, get_models | 21 | from training.functional import train, add_placeholder_tokens, get_models |
| 22 | from training.lr import plot_metrics | 22 | from training.lr import plot_metrics |
| 23 | from training.strategy.lora import lora_strategy | 23 | from training.strategy.lora import lora_strategy |
| 24 | from training.strategy.ti import textual_inversion_strategy | ||
| 25 | from training.optimization import get_scheduler | 24 | from training.optimization import get_scheduler |
| 26 | from training.util import save_args | 25 | from training.util import save_args |
| 27 | 26 | ||
| @@ -829,6 +828,12 @@ def main(): | |||
| 829 | sample_num_batches=args.sample_batches, | 828 | sample_num_batches=args.sample_batches, |
| 830 | sample_num_steps=args.sample_steps, | 829 | sample_num_steps=args.sample_steps, |
| 831 | sample_image_size=args.sample_image_size, | 830 | sample_image_size=args.sample_image_size, |
| 831 | placeholder_tokens=args.placeholder_tokens, | ||
| 832 | placeholder_token_ids=placeholder_token_ids, | ||
| 833 | use_emb_decay=args.use_emb_decay, | ||
| 834 | emb_decay_target=args.emb_decay_target, | ||
| 835 | emb_decay=args.emb_decay, | ||
| 836 | max_grad_norm=args.max_grad_norm, | ||
| 832 | ) | 837 | ) |
| 833 | 838 | ||
| 834 | create_datamodule = partial( | 839 | create_datamodule = partial( |
| @@ -907,7 +912,8 @@ def main(): | |||
| 907 | ) | 912 | ) |
| 908 | 913 | ||
| 909 | metrics = trainer( | 914 | metrics = trainer( |
| 910 | strategy=textual_inversion_strategy, | 915 | strategy=lora_strategy, |
| 916 | pti_mode=True, | ||
| 911 | project="pti", | 917 | project="pti", |
| 912 | train_dataloader=pti_datamodule.train_dataloader, | 918 | train_dataloader=pti_datamodule.train_dataloader, |
| 913 | val_dataloader=pti_datamodule.val_dataloader, | 919 | val_dataloader=pti_datamodule.val_dataloader, |
| @@ -919,11 +925,6 @@ def main(): | |||
| 919 | sample_output_dir=pti_sample_output_dir, | 925 | sample_output_dir=pti_sample_output_dir, |
| 920 | checkpoint_output_dir=pti_checkpoint_output_dir, | 926 | checkpoint_output_dir=pti_checkpoint_output_dir, |
| 921 | sample_frequency=math.inf, | 927 | sample_frequency=math.inf, |
| 922 | placeholder_tokens=args.placeholder_tokens, | ||
| 923 | placeholder_token_ids=placeholder_token_ids, | ||
| 924 | use_emb_decay=args.use_emb_decay, | ||
| 925 | emb_decay_target=args.emb_decay_target, | ||
| 926 | emb_decay=args.emb_decay, | ||
| 927 | ) | 928 | ) |
| 928 | 929 | ||
| 929 | plot_metrics(metrics, pti_output_dir / "lr.png") | 930 | plot_metrics(metrics, pti_output_dir / "lr.png") |
| @@ -952,13 +953,21 @@ def main(): | |||
| 952 | lora_optimizer = create_optimizer( | 953 | lora_optimizer = create_optimizer( |
| 953 | [ | 954 | [ |
| 954 | { | 955 | { |
| 955 | "params": unet.parameters(), | 956 | "params": ( |
| 957 | param | ||
| 958 | for param in unet.parameters() | ||
| 959 | if param.requires_grad | ||
| 960 | ), | ||
| 956 | "lr": args.learning_rate_unet, | 961 | "lr": args.learning_rate_unet, |
| 957 | }, | 962 | }, |
| 958 | { | 963 | { |
| 959 | "params": itertools.chain( | 964 | "params": ( |
| 960 | text_encoder.text_model.encoder.parameters(), | 965 | param |
| 961 | text_encoder.text_model.final_layer_norm.parameters(), | 966 | for param in itertools.chain( |
| 967 | text_encoder.text_model.encoder.parameters(), | ||
| 968 | text_encoder.text_model.final_layer_norm.parameters(), | ||
| 969 | ) | ||
| 970 | if param.requires_grad | ||
| 962 | ), | 971 | ), |
| 963 | "lr": args.learning_rate_text, | 972 | "lr": args.learning_rate_text, |
| 964 | }, | 973 | }, |
| @@ -990,12 +999,6 @@ def main(): | |||
| 990 | sample_output_dir=lora_sample_output_dir, | 999 | sample_output_dir=lora_sample_output_dir, |
| 991 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1000 | checkpoint_output_dir=lora_checkpoint_output_dir, |
| 992 | sample_frequency=lora_sample_frequency, | 1001 | sample_frequency=lora_sample_frequency, |
| 993 | placeholder_tokens=args.placeholder_tokens, | ||
| 994 | placeholder_token_ids=placeholder_token_ids, | ||
| 995 | use_emb_decay=args.use_emb_decay, | ||
| 996 | emb_decay_target=args.emb_decay_target, | ||
| 997 | emb_decay=args.emb_decay, | ||
| 998 | max_grad_norm=args.max_grad_norm, | ||
| 999 | ) | 1002 | ) |
| 1000 | 1003 | ||
| 1001 | plot_metrics(metrics, lora_output_dir / "lr.png") | 1004 | plot_metrics(metrics, lora_output_dir / "lr.png") |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 80ffa9c..912ff26 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -34,6 +34,7 @@ def lora_strategy_callbacks( | |||
| 34 | seed: int, | 34 | seed: int, |
| 35 | placeholder_tokens: list[str], | 35 | placeholder_tokens: list[str], |
| 36 | placeholder_token_ids: list[list[int]], | 36 | placeholder_token_ids: list[list[int]], |
| 37 | pti_mode: bool = False, | ||
| 37 | use_emb_decay: bool = False, | 38 | use_emb_decay: bool = False, |
| 38 | emb_decay_target: float = 0.4, | 39 | emb_decay_target: float = 0.4, |
| 39 | emb_decay: float = 1e-2, | 40 | emb_decay: float = 1e-2, |
| @@ -79,10 +80,11 @@ def lora_strategy_callbacks( | |||
| 79 | yield | 80 | yield |
| 80 | 81 | ||
| 81 | def on_before_optimize(lr: float, epoch: int): | 82 | def on_before_optimize(lr: float, epoch: int): |
| 82 | accelerator.clip_grad_norm_( | 83 | if not pti_mode: |
| 83 | itertools.chain(unet.parameters(), text_encoder.parameters()), | 84 | accelerator.clip_grad_norm_( |
| 84 | max_grad_norm | 85 | itertools.chain(unet.parameters(), text_encoder.parameters()), |
| 85 | ) | 86 | max_grad_norm |
| 87 | ) | ||
| 86 | 88 | ||
| 87 | if use_emb_decay: | 89 | if use_emb_decay: |
| 88 | params = [ | 90 | params = [ |
| @@ -117,20 +119,21 @@ def lora_strategy_callbacks( | |||
| 117 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 119 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" |
| 118 | ) | 120 | ) |
| 119 | 121 | ||
| 120 | lora_config = {} | 122 | if not pti_mode: |
| 121 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) | 123 | lora_config = {} |
| 122 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) | 124 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) |
| 125 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) | ||
| 123 | 126 | ||
| 124 | text_encoder_state_dict = get_peft_model_state_dict( | 127 | text_encoder_state_dict = get_peft_model_state_dict( |
| 125 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) | 128 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) |
| 126 | ) | 129 | ) |
| 127 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} | 130 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} |
| 128 | state_dict.update(text_encoder_state_dict) | 131 | state_dict.update(text_encoder_state_dict) |
| 129 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 132 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) |
| 130 | 133 | ||
| 131 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") | 134 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") |
| 132 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 135 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
| 133 | json.dump(lora_config, f) | 136 | json.dump(lora_config, f) |
| 134 | 137 | ||
| 135 | del unet_ | 138 | del unet_ |
| 136 | del text_encoder_ | 139 | del text_encoder_ |
