summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2023-04-20 12:29:58 +0200
committer Volpeon <git@volpeon.ink> 2023-04-20 12:29:58 +0200
commit 050dcbde0483c277cd632e05d8a0f73c87332785 (patch)
tree 4bd86b58929f564201f96bbdf73263bfb559648f
parent Fix (diff)
download textual-inversion-diff-050dcbde0483c277cd632e05d8a0f73c87332785.tar.gz
textual-inversion-diff-050dcbde0483c277cd632e05d8a0f73c87332785.tar.bz2
textual-inversion-diff-050dcbde0483c277cd632e05d8a0f73c87332785.zip
Update
-rw-r--r-- pipelines/stable_diffusion/vlpn_stable_diffusion.py 2
-rw-r--r-- train_lora.py 145
-rw-r--r-- train_ti.py 13
3 files changed, 126 insertions, 34 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a0dff54..aa3dbc6 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -301,7 +301,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
301 init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 301 init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
302 302
303 t_start = max(num_inference_steps - init_timestep, 0) 303 t_start = max(num_inference_steps - init_timestep, 0)
304 timesteps = self.scheduler.timesteps[t_start:] 304 timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
305 305
306 timesteps = timesteps.to(device) 306 timesteps = timesteps.to(device)
307 307
diff --git a/train_lora.py b/train_lora.py
index d0313fe..0ae8b31 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -303,6 +303,11 @@ def parse_args():
303 help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", 303 help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
304 ) 304 )
305 parser.add_argument( 305 parser.add_argument(
306 "--train_text_encoder_cycles",
307 default=999999,
308 help="Number of epochs the text encoder will be trained."
309 )
310 parser.add_argument(
306 "--find_lr", 311 "--find_lr",
307 action="store_true", 312 action="store_true",
308 help="Automatically find a learning rate (no training).", 313 help="Automatically find a learning rate (no training).",
@@ -919,6 +924,78 @@ def main():
919 mid_point=args.lr_mid_point, 924 mid_point=args.lr_mid_point,
920 ) 925 )
921 926
927 # PTI
928 # --------------------------------------------------------------------------------
929
930 if len(args.placeholder_tokens) != 0:
931 filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
932
933 pti_datamodule = create_datamodule(
934 batch_size=args.train_batch_size,
935 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
936 )
937 pti_datamodule.setup()
938
939 num_train_epochs = args.num_train_epochs
940 pti_sample_frequency = args.sample_frequency
941 if num_train_epochs is None:
942 num_train_epochs = math.ceil(
943 args.num_train_steps / len(pti_datamodule.train_dataset)
944 ) * args.gradient_accumulation_steps
945 pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps))
946 num_training_steps_per_epoch = math.ceil(
947 len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps)
948 num_train_steps = num_training_steps_per_epoch * num_train_epochs
949 if args.sample_num is not None:
950 pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
951
952 pti_project = "pti"
953
954 if accelerator.is_main_process:
955 accelerator.init_trackers(pti_project)
956
957 pti_sample_output_dir = output_dir / pti_project / "samples"
958
959 print("")
960 print(f"============ PTI ============")
961 print("")
962
963 pti_optimizer = create_optimizer([{
964 "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
965 "lr": args.learning_rate_emb,
966 "weight_decay": 0,
967 }])
968
969 pti_lr_scheduler = create_lr_scheduler(
970 "constant_with_warmup",
971 gradient_accumulation_steps=args.gradient_accumulation_steps,
972 optimizer=pti_optimizer,
973 num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
974 train_epochs=num_train_epochs,
975 warmup_epochs=math.ceil(0.1 * num_train_epochs),
976 )
977
978 pti_checkpoint_output_dir = output_dir / pti_project / "model"
979
980 trainer(
981 strategy=lora_strategy,
982 train_dataloader=pti_datamodule.train_dataloader,
983 val_dataloader=pti_datamodule.val_dataloader,
984 optimizer=pti_optimizer,
985 lr_scheduler=pti_lr_scheduler,
986 num_train_epochs=num_train_epochs,
987 gradient_accumulation_steps=args.gradient_accumulation_steps,
988 cycle=1,
989 pti_mode=True,
990 # --
991 group_labels=["emb"],
992 sample_output_dir=pti_sample_output_dir,
993 checkpoint_output_dir=pti_checkpoint_output_dir,
994 sample_frequency=pti_sample_frequency,
995 )
996
997 embeddings.persist()
998
922 # LORA 999 # LORA
923 # -------------------------------------------------------------------------------- 1000 # --------------------------------------------------------------------------------
924 1001
@@ -941,16 +1018,6 @@ def main():
941 if args.sample_num is not None: 1018 if args.sample_num is not None:
942 lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) 1019 lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
943 1020
944 group_labels = []
945 if len(args.placeholder_tokens) != 0:
946 group_labels.append("emb")
947 group_labels += ["unet", "text"]
948
949 training_iter = 0
950 learning_rate_emb = args.learning_rate_emb
951 learning_rate_unet = args.learning_rate_unet
952 learning_rate_text = args.learning_rate_text
953
954 lora_project = "lora" 1021 lora_project = "lora"
955 1022
956 if accelerator.is_main_process: 1023 if accelerator.is_main_process:
@@ -958,7 +1025,11 @@ def main():
958 1025
959 lora_sample_output_dir = output_dir / lora_project / "samples" 1026 lora_sample_output_dir = output_dir / lora_project / "samples"
960 1027
1028 training_iter = 0
961 auto_cycles = list(args.auto_cycles) 1029 auto_cycles = list(args.auto_cycles)
1030 learning_rate_emb = args.learning_rate_emb
1031 learning_rate_unet = args.learning_rate_unet
1032 learning_rate_text = args.learning_rate_text
962 lr_scheduler = args.lr_scheduler 1033 lr_scheduler = args.lr_scheduler
963 lr_warmup_epochs = args.lr_warmup_epochs 1034 lr_warmup_epochs = args.lr_warmup_epochs
964 lr_cycles = args.lr_cycles 1035 lr_cycles = args.lr_cycles
@@ -970,6 +1041,18 @@ def main():
970 response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") 1041 response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
971 1042
972 if response.lower().strip() == "o": 1043 if response.lower().strip() == "o":
1044 if args.learning_rate_emb is not None:
1045 learning_rate_emb = args.learning_rate_emb * 2
1046 if args.learning_rate_unet is not None:
1047 learning_rate_unet = args.learning_rate_unet * 2
1048 if args.learning_rate_text is not None:
1049 learning_rate_text = args.learning_rate_text * 2
1050 else:
1051 learning_rate_emb = args.learning_rate_emb
1052 learning_rate_unet = args.learning_rate_unet
1053 learning_rate_text = args.learning_rate_text
1054
1055 if response.lower().strip() == "o":
973 lr_scheduler = "one_cycle" 1056 lr_scheduler = "one_cycle"
974 lr_warmup_epochs = args.lr_warmup_epochs 1057 lr_warmup_epochs = args.lr_warmup_epochs
975 lr_cycles = args.lr_cycles 1058 lr_cycles = args.lr_cycles
@@ -986,28 +1069,32 @@ def main():
986 break 1069 break
987 1070
988 print("") 1071 print("")
989 print(f"============ LoRA cycle {training_iter + 1} ============") 1072 print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
990 print("") 1073 print("")
991 1074
992 params_to_optimize = [] 1075 params_to_optimize = []
1076 group_labels = []
1077
1078 params_to_optimize.append({
1079 "params": (
1080 param
1081 for param in unet.parameters()
1082 if param.requires_grad
1083 ),
1084 "lr": learning_rate_unet,
1085 })
1086 group_labels.append("unet")
1087
1088 if training_iter < args.train_text_encoder_cycles:
1089 # if len(args.placeholder_tokens) != 0:
1090 # params_to_optimize.append({
1091 # "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
1092 # "lr": learning_rate_emb,
1093 # "weight_decay": 0,
1094 # })
1095 # group_labels.append("emb")
993 1096
994 if len(args.placeholder_tokens) != 0:
995 params_to_optimize.append({ 1097 params_to_optimize.append({
996 "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
997 "lr": learning_rate_emb,
998 "weight_decay": 0,
999 })
1000 group_labels.append("emb")
1001 params_to_optimize += [
1002 {
1003 "params": (
1004 param
1005 for param in unet.parameters()
1006 if param.requires_grad
1007 ),
1008 "lr": learning_rate_unet,
1009 },
1010 {
1011 "params": ( 1098 "params": (
1012 param 1099 param
1013 for param in itertools.chain( 1100 for param in itertools.chain(
@@ -1017,8 +1104,8 @@ def main():
1017 if param.requires_grad 1104 if param.requires_grad
1018 ), 1105 ),
1019 "lr": learning_rate_text, 1106 "lr": learning_rate_text,
1020 }, 1107 })
1021 ] 1108 group_labels.append("text")
1022 1109
1023 lora_optimizer = create_optimizer(params_to_optimize) 1110 lora_optimizer = create_optimizer(params_to_optimize)
1024 1111
diff --git a/train_ti.py b/train_ti.py
index b00b0d7..84ca296 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -906,9 +906,6 @@ def main():
906 if args.sample_num is not None: 906 if args.sample_num is not None:
907 sample_frequency = math.ceil(num_train_epochs / args.sample_num) 907 sample_frequency = math.ceil(num_train_epochs / args.sample_num)
908 908
909 training_iter = 0
910 learning_rate = args.learning_rate
911
912 project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" 909 project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
913 910
914 if accelerator.is_main_process: 911 if accelerator.is_main_process:
@@ -916,7 +913,9 @@ def main():
916 913
917 sample_output_dir = output_dir / project / "samples" 914 sample_output_dir = output_dir / project / "samples"
918 915
916 training_iter = 0
919 auto_cycles = list(args.auto_cycles) 917 auto_cycles = list(args.auto_cycles)
918 learning_rate = args.learning_rate
920 lr_scheduler = args.lr_scheduler 919 lr_scheduler = args.lr_scheduler
921 lr_warmup_epochs = args.lr_warmup_epochs 920 lr_warmup_epochs = args.lr_warmup_epochs
922 lr_cycles = args.lr_cycles 921 lr_cycles = args.lr_cycles
@@ -929,6 +928,12 @@ def main():
929 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") 928 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
930 929
931 if response.lower().strip() == "o": 930 if response.lower().strip() == "o":
931 if args.learning_rate is not None:
932 learning_rate = args.learning_rate * 2
933 else:
934 learning_rate = args.learning_rate
935
936 if response.lower().strip() == "o":
932 lr_scheduler = "one_cycle" 937 lr_scheduler = "one_cycle"
933 lr_warmup_epochs = args.lr_warmup_epochs 938 lr_warmup_epochs = args.lr_warmup_epochs
934 lr_cycles = args.lr_cycles 939 lr_cycles = args.lr_cycles
@@ -945,7 +950,7 @@ def main():
945 break 950 break
946 951
947 print("") 952 print("")
948 print(f"------------ TI cycle {training_iter + 1} ------------") 953 print(f"------------ TI cycle {training_iter + 1}: {response} ------------")
949 print("") 954 print("")
950 955
951 optimizer = create_optimizer( 956 optimizer = create_optimizer(