-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py |   2
-rw-r--r--   train_lora.py                                        | 145
-rw-r--r--   train_ti.py                                          |  13

3 files changed, 126 insertions, 34 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a0dff54..aa3dbc6 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -301,7 +301,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

         t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]

         timesteps = timesteps.to(device)

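The pipeline change above accounts for higher-order schedulers, whose timesteps tensor holds scheduler.order entries per denoising step; slicing by t_start alone skips too few entries whenever order > 1. A minimal sketch of the idea, assuming a diffusers-style scheduler API (the helper name and return value are assumed, not taken from this file):

    def get_timesteps(scheduler, num_inference_steps, strength, device):
        # Number of steps actually executed for img2img at this strength.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)

        # Skip whole steps: each step occupies scheduler.order entries
        # (e.g. 2 for Heun-style schedulers), not just one.
        timesteps = scheduler.timesteps[t_start * scheduler.order:]
        return timesteps.to(device), num_inference_steps - t_start
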
diff --git a/train_lora.py b/train_lora.py
index d0313fe..0ae8b31 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -303,6 +303,11 @@ def parse_args():
         help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
     )
     parser.add_argument(
+        "--train_text_encoder_cycles",
+        default=999999,
+        help="Number of epochs the text encoder will be trained."
+    )
+    parser.add_argument(
         "--find_lr",
         action="store_true",
         help="Automatically find a learning rate (no training).",
@@ -919,6 +924,78 @@ def main():
         mid_point=args.lr_mid_point,
     )

+    # PTI
+    # --------------------------------------------------------------------------------
+
+    if len(args.placeholder_tokens) != 0:
+        filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
+
+        pti_datamodule = create_datamodule(
+            batch_size=args.train_batch_size,
+            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
+
+        num_train_epochs = args.num_train_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_train_epochs is None:
+            num_train_epochs = math.ceil(
+                args.num_train_steps / len(pti_datamodule.train_dataset)
+            ) * args.gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps))
+        num_training_steps_per_epoch = math.ceil(
+            len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
+        if args.sample_num is not None:
+            pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
+
+        pti_project = "pti"
+
+        if accelerator.is_main_process:
+            accelerator.init_trackers(pti_project)
+
+        pti_sample_output_dir = output_dir / pti_project / "samples"
+
+        print("")
+        print(f"============ PTI ============")
+        print("")
+
+        pti_optimizer = create_optimizer([{
+            "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
+            "lr": args.learning_rate_emb,
+            "weight_decay": 0,
+        }])
+
+        pti_lr_scheduler = create_lr_scheduler(
+            "constant_with_warmup",
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_train_epochs,
+            warmup_epochs=math.ceil(0.1 * num_train_epochs),
+        )
+
+        pti_checkpoint_output_dir = output_dir / pti_project / "model"
+
+        trainer(
+            strategy=lora_strategy,
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_train_epochs,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            cycle=1,
+            pti_mode=True,
+            # --
+            group_labels=["emb"],
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+        )
+
+        embeddings.persist()
+
     # LORA
     # --------------------------------------------------------------------------------

@@ -941,16 +1018,6 @@ def main():
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)

-    group_labels = []
-    if len(args.placeholder_tokens) != 0:
-        group_labels.append("emb")
-    group_labels += ["unet", "text"]
-
-    training_iter = 0
-    learning_rate_emb = args.learning_rate_emb
-    learning_rate_unet = args.learning_rate_unet
-    learning_rate_text = args.learning_rate_text
-
     lora_project = "lora"

     if accelerator.is_main_process:
@@ -958,7 +1025,11 @@ def main():

     lora_sample_output_dir = output_dir / lora_project / "samples"

+    training_iter = 0
     auto_cycles = list(args.auto_cycles)
+    learning_rate_emb = args.learning_rate_emb
+    learning_rate_unet = args.learning_rate_unet
+    learning_rate_text = args.learning_rate_text
     lr_scheduler = args.lr_scheduler
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
@@ -970,6 +1041,18 @@ def main():
         response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")

         if response.lower().strip() == "o":
+            if args.learning_rate_emb is not None:
+                learning_rate_emb = args.learning_rate_emb * 2
+            if args.learning_rate_unet is not None:
+                learning_rate_unet = args.learning_rate_unet * 2
+            if args.learning_rate_text is not None:
+                learning_rate_text = args.learning_rate_text * 2
+        else:
+            learning_rate_emb = args.learning_rate_emb
+            learning_rate_unet = args.learning_rate_unet
+            learning_rate_text = args.learning_rate_text
+
+        if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
             lr_warmup_epochs = args.lr_warmup_epochs
             lr_cycles = args.lr_cycles
@@ -986,28 +1069,32 @@ def main():
             break

         print("")
-        print(f"============ LoRA cycle {training_iter + 1} ============")
+        print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
         print("")

         params_to_optimize = []
+        group_labels = []
+
+        params_to_optimize.append({
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_unet,
+        })
+        group_labels.append("unet")
+
+        if training_iter < args.train_text_encoder_cycles:
+            # if len(args.placeholder_tokens) != 0:
+            #     params_to_optimize.append({
+            #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
+            #         "lr": learning_rate_emb,
+            #         "weight_decay": 0,
+            #     })
+            #     group_labels.append("emb")

-        if len(args.placeholder_tokens) != 0:
             params_to_optimize.append({
-                "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
-                "lr": learning_rate_emb,
-                "weight_decay": 0,
-            })
-            group_labels.append("emb")
-        params_to_optimize += [
-            {
-                "params": (
-                    param
-                    for param in unet.parameters()
-                    if param.requires_grad
-                ),
-                "lr": learning_rate_unet,
-            },
-            {
                 "params": (
                     param
                     for param in itertools.chain(
@@ -1017,8 +1104,8 @@ def main():
                     if param.requires_grad
                 ),
                 "lr": learning_rate_text,
-            },
-        ]
+            })
+            group_labels.append("text")

         lora_optimizer = create_optimizer(params_to_optimize)

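Taken together, the train_lora.py changes add a PTI pass that trains only the new token embeddings before any LoRA cycle, and rebuild the optimizer parameter groups on every cycle so the text encoder is trained only while training_iter < args.train_text_encoder_cycles (default 999999, i.e. effectively always). A condensed sketch of that per-cycle grouping; the exact text-encoder modules chained in the diff fall outside the hunk, so text_encoder.parameters() stands in for them here:

    # Sketch only: unet, text_encoder, args and the learning rates are the
    # script-level objects inside train_lora.py's main().
    params_to_optimize = [{
        "params": (p for p in unet.parameters() if p.requires_grad),
        "lr": learning_rate_unet,
    }]
    group_labels = ["unet"]

    if training_iter < args.train_text_encoder_cycles:
        params_to_optimize.append({
            # Placeholder for the itertools.chain(...) of text-encoder modules.
            "params": (p for p in text_encoder.parameters() if p.requires_grad),
            "lr": learning_rate_text,
        })
        group_labels.append("text")
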
diff --git a/train_ti.py b/train_ti.py
index b00b0d7..84ca296 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -906,9 +906,6 @@ def main():
     if args.sample_num is not None:
         sample_frequency = math.ceil(num_train_epochs / args.sample_num)

-    training_iter = 0
-    learning_rate = args.learning_rate
-
     project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"

     if accelerator.is_main_process:
@@ -916,7 +913,9 @@ def main():

     sample_output_dir = output_dir / project / "samples"

+    training_iter = 0
     auto_cycles = list(args.auto_cycles)
+    learning_rate = args.learning_rate
     lr_scheduler = args.lr_scheduler
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
@@ -929,6 +928,12 @@ def main():
             "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")

         if response.lower().strip() == "o":
+            if args.learning_rate is not None:
+                learning_rate = args.learning_rate * 2
+        else:
+            learning_rate = args.learning_rate
+
+        if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
             lr_warmup_epochs = args.lr_warmup_epochs
             lr_cycles = args.lr_cycles
@@ -945,7 +950,7 @@ def main():
             break

         print("")
-        print(f"------------ TI cycle {training_iter + 1} ------------")
+        print(f"------------ TI cycle {training_iter + 1}: {response} ------------")
         print("")

         optimizer = create_optimizer(
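train_ti.py mirrors the LoRA script: the learning rate is re-derived at the top of each interactive cycle, doubled when a fresh one_cycle schedule is chosen and otherwise reset to the configured value, so a doubled rate from an earlier cycle does not carry over into a later warmup/constant/decay cycle. A small self-contained illustration of that state across two cycles (base rate and responses are made up):

    base_lr = 1e-4                        # stands in for args.learning_rate
    learning_rate = base_lr

    for response in ["o", "c"]:           # a one_cycle run, then a constant run
        if response == "o":
            if base_lr is not None:
                learning_rate = base_lr * 2   # peak rate for the one-cycle schedule
        else:
            learning_rate = base_lr           # restore the configured rate
        print(response, learning_rate)        # prints: o 0.0002, then c 0.0001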