diff options
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 2 | ||||
| -rw-r--r-- | train_lora.py | 145 | ||||
| -rw-r--r-- | train_ti.py | 13 |
3 files changed, 126 insertions, 34 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a0dff54..aa3dbc6 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -301,7 +301,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 301 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) | 301 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) |
| 302 | 302 | ||
| 303 | t_start = max(num_inference_steps - init_timestep, 0) | 303 | t_start = max(num_inference_steps - init_timestep, 0) |
| 304 | timesteps = self.scheduler.timesteps[t_start:] | 304 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] |
| 305 | 305 | ||
| 306 | timesteps = timesteps.to(device) | 306 | timesteps = timesteps.to(device) |
| 307 | 307 | ||
diff --git a/train_lora.py b/train_lora.py index d0313fe..0ae8b31 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -303,6 +303,11 @@ def parse_args(): | |||
| 303 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", | 303 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", |
| 304 | ) | 304 | ) |
| 305 | parser.add_argument( | 305 | parser.add_argument( |
| 306 | "--train_text_encoder_cycles", | ||
| 307 | default=999999, | ||
| 308 | help="Number of epochs the text encoder will be trained." | ||
| 309 | ) | ||
| 310 | parser.add_argument( | ||
| 306 | "--find_lr", | 311 | "--find_lr", |
| 307 | action="store_true", | 312 | action="store_true", |
| 308 | help="Automatically find a learning rate (no training).", | 313 | help="Automatically find a learning rate (no training).", |
| @@ -919,6 +924,78 @@ def main(): | |||
| 919 | mid_point=args.lr_mid_point, | 924 | mid_point=args.lr_mid_point, |
| 920 | ) | 925 | ) |
| 921 | 926 | ||
| 927 | # PTI | ||
| 928 | # -------------------------------------------------------------------------------- | ||
| 929 | |||
| 930 | if len(args.placeholder_tokens) != 0: | ||
| 931 | filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens] | ||
| 932 | |||
| 933 | pti_datamodule = create_datamodule( | ||
| 934 | batch_size=args.train_batch_size, | ||
| 935 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | ||
| 936 | ) | ||
| 937 | pti_datamodule.setup() | ||
| 938 | |||
| 939 | num_train_epochs = args.num_train_epochs | ||
| 940 | pti_sample_frequency = args.sample_frequency | ||
| 941 | if num_train_epochs is None: | ||
| 942 | num_train_epochs = math.ceil( | ||
| 943 | args.num_train_steps / len(pti_datamodule.train_dataset) | ||
| 944 | ) * args.gradient_accumulation_steps | ||
| 945 | pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps)) | ||
| 946 | num_training_steps_per_epoch = math.ceil( | ||
| 947 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps) | ||
| 948 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | ||
| 949 | if args.sample_num is not None: | ||
| 950 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | ||
| 951 | |||
| 952 | pti_project = "pti" | ||
| 953 | |||
| 954 | if accelerator.is_main_process: | ||
| 955 | accelerator.init_trackers(pti_project) | ||
| 956 | |||
| 957 | pti_sample_output_dir = output_dir / pti_project / "samples" | ||
| 958 | |||
| 959 | print("") | ||
| 960 | print(f"============ PTI ============") | ||
| 961 | print("") | ||
| 962 | |||
| 963 | pti_optimizer = create_optimizer([{ | ||
| 964 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
| 965 | "lr": args.learning_rate_emb, | ||
| 966 | "weight_decay": 0, | ||
| 967 | }]) | ||
| 968 | |||
| 969 | pti_lr_scheduler = create_lr_scheduler( | ||
| 970 | "constant_with_warmup", | ||
| 971 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 972 | optimizer=pti_optimizer, | ||
| 973 | num_training_steps_per_epoch=len(pti_datamodule.train_dataloader), | ||
| 974 | train_epochs=num_train_epochs, | ||
| 975 | warmup_epochs=math.ceil(0.1 * num_train_epochs), | ||
| 976 | ) | ||
| 977 | |||
| 978 | pti_checkpoint_output_dir = output_dir / pti_project / "model" | ||
| 979 | |||
| 980 | trainer( | ||
| 981 | strategy=lora_strategy, | ||
| 982 | train_dataloader=pti_datamodule.train_dataloader, | ||
| 983 | val_dataloader=pti_datamodule.val_dataloader, | ||
| 984 | optimizer=pti_optimizer, | ||
| 985 | lr_scheduler=pti_lr_scheduler, | ||
| 986 | num_train_epochs=num_train_epochs, | ||
| 987 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 988 | cycle=1, | ||
| 989 | pti_mode=True, | ||
| 990 | # -- | ||
| 991 | group_labels=["emb"], | ||
| 992 | sample_output_dir=pti_sample_output_dir, | ||
| 993 | checkpoint_output_dir=pti_checkpoint_output_dir, | ||
| 994 | sample_frequency=pti_sample_frequency, | ||
| 995 | ) | ||
| 996 | |||
| 997 | embeddings.persist() | ||
| 998 | |||
| 922 | # LORA | 999 | # LORA |
| 923 | # -------------------------------------------------------------------------------- | 1000 | # -------------------------------------------------------------------------------- |
| 924 | 1001 | ||
| @@ -941,16 +1018,6 @@ def main(): | |||
| 941 | if args.sample_num is not None: | 1018 | if args.sample_num is not None: |
| 942 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1019 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
| 943 | 1020 | ||
| 944 | group_labels = [] | ||
| 945 | if len(args.placeholder_tokens) != 0: | ||
| 946 | group_labels.append("emb") | ||
| 947 | group_labels += ["unet", "text"] | ||
| 948 | |||
| 949 | training_iter = 0 | ||
| 950 | learning_rate_emb = args.learning_rate_emb | ||
| 951 | learning_rate_unet = args.learning_rate_unet | ||
| 952 | learning_rate_text = args.learning_rate_text | ||
| 953 | |||
| 954 | lora_project = "lora" | 1021 | lora_project = "lora" |
| 955 | 1022 | ||
| 956 | if accelerator.is_main_process: | 1023 | if accelerator.is_main_process: |
| @@ -958,7 +1025,11 @@ def main(): | |||
| 958 | 1025 | ||
| 959 | lora_sample_output_dir = output_dir / lora_project / "samples" | 1026 | lora_sample_output_dir = output_dir / lora_project / "samples" |
| 960 | 1027 | ||
| 1028 | training_iter = 0 | ||
| 961 | auto_cycles = list(args.auto_cycles) | 1029 | auto_cycles = list(args.auto_cycles) |
| 1030 | learning_rate_emb = args.learning_rate_emb | ||
| 1031 | learning_rate_unet = args.learning_rate_unet | ||
| 1032 | learning_rate_text = args.learning_rate_text | ||
| 962 | lr_scheduler = args.lr_scheduler | 1033 | lr_scheduler = args.lr_scheduler |
| 963 | lr_warmup_epochs = args.lr_warmup_epochs | 1034 | lr_warmup_epochs = args.lr_warmup_epochs |
| 964 | lr_cycles = args.lr_cycles | 1035 | lr_cycles = args.lr_cycles |
| @@ -970,6 +1041,18 @@ def main(): | |||
| 970 | response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 1041 | response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
| 971 | 1042 | ||
| 972 | if response.lower().strip() == "o": | 1043 | if response.lower().strip() == "o": |
| 1044 | if args.learning_rate_emb is not None: | ||
| 1045 | learning_rate_emb = args.learning_rate_emb * 2 | ||
| 1046 | if args.learning_rate_unet is not None: | ||
| 1047 | learning_rate_unet = args.learning_rate_unet * 2 | ||
| 1048 | if args.learning_rate_text is not None: | ||
| 1049 | learning_rate_text = args.learning_rate_text * 2 | ||
| 1050 | else: | ||
| 1051 | learning_rate_emb = args.learning_rate_emb | ||
| 1052 | learning_rate_unet = args.learning_rate_unet | ||
| 1053 | learning_rate_text = args.learning_rate_text | ||
| 1054 | |||
| 1055 | if response.lower().strip() == "o": | ||
| 973 | lr_scheduler = "one_cycle" | 1056 | lr_scheduler = "one_cycle" |
| 974 | lr_warmup_epochs = args.lr_warmup_epochs | 1057 | lr_warmup_epochs = args.lr_warmup_epochs |
| 975 | lr_cycles = args.lr_cycles | 1058 | lr_cycles = args.lr_cycles |
| @@ -986,28 +1069,32 @@ def main(): | |||
| 986 | break | 1069 | break |
| 987 | 1070 | ||
| 988 | print("") | 1071 | print("") |
| 989 | print(f"============ LoRA cycle {training_iter + 1} ============") | 1072 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") |
| 990 | print("") | 1073 | print("") |
| 991 | 1074 | ||
| 992 | params_to_optimize = [] | 1075 | params_to_optimize = [] |
| 1076 | group_labels = [] | ||
| 1077 | |||
| 1078 | params_to_optimize.append({ | ||
| 1079 | "params": ( | ||
| 1080 | param | ||
| 1081 | for param in unet.parameters() | ||
| 1082 | if param.requires_grad | ||
| 1083 | ), | ||
| 1084 | "lr": learning_rate_unet, | ||
| 1085 | }) | ||
| 1086 | group_labels.append("unet") | ||
| 1087 | |||
| 1088 | if training_iter < args.train_text_encoder_cycles: | ||
| 1089 | # if len(args.placeholder_tokens) != 0: | ||
| 1090 | # params_to_optimize.append({ | ||
| 1091 | # "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
| 1092 | # "lr": learning_rate_emb, | ||
| 1093 | # "weight_decay": 0, | ||
| 1094 | # }) | ||
| 1095 | # group_labels.append("emb") | ||
| 993 | 1096 | ||
| 994 | if len(args.placeholder_tokens) != 0: | ||
| 995 | params_to_optimize.append({ | 1097 | params_to_optimize.append({ |
| 996 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
| 997 | "lr": learning_rate_emb, | ||
| 998 | "weight_decay": 0, | ||
| 999 | }) | ||
| 1000 | group_labels.append("emb") | ||
| 1001 | params_to_optimize += [ | ||
| 1002 | { | ||
| 1003 | "params": ( | ||
| 1004 | param | ||
| 1005 | for param in unet.parameters() | ||
| 1006 | if param.requires_grad | ||
| 1007 | ), | ||
| 1008 | "lr": learning_rate_unet, | ||
| 1009 | }, | ||
| 1010 | { | ||
| 1011 | "params": ( | 1098 | "params": ( |
| 1012 | param | 1099 | param |
| 1013 | for param in itertools.chain( | 1100 | for param in itertools.chain( |
| @@ -1017,8 +1104,8 @@ def main(): | |||
| 1017 | if param.requires_grad | 1104 | if param.requires_grad |
| 1018 | ), | 1105 | ), |
| 1019 | "lr": learning_rate_text, | 1106 | "lr": learning_rate_text, |
| 1020 | }, | 1107 | }) |
| 1021 | ] | 1108 | group_labels.append("text") |
| 1022 | 1109 | ||
| 1023 | lora_optimizer = create_optimizer(params_to_optimize) | 1110 | lora_optimizer = create_optimizer(params_to_optimize) |
| 1024 | 1111 | ||
diff --git a/train_ti.py b/train_ti.py index b00b0d7..84ca296 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -906,9 +906,6 @@ def main(): | |||
| 906 | if args.sample_num is not None: | 906 | if args.sample_num is not None: |
| 907 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 907 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
| 908 | 908 | ||
| 909 | training_iter = 0 | ||
| 910 | learning_rate = args.learning_rate | ||
| 911 | |||
| 912 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" | 909 | project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti" |
| 913 | 910 | ||
| 914 | if accelerator.is_main_process: | 911 | if accelerator.is_main_process: |
| @@ -916,7 +913,9 @@ def main(): | |||
| 916 | 913 | ||
| 917 | sample_output_dir = output_dir / project / "samples" | 914 | sample_output_dir = output_dir / project / "samples" |
| 918 | 915 | ||
| 916 | training_iter = 0 | ||
| 919 | auto_cycles = list(args.auto_cycles) | 917 | auto_cycles = list(args.auto_cycles) |
| 918 | learning_rate = args.learning_rate | ||
| 920 | lr_scheduler = args.lr_scheduler | 919 | lr_scheduler = args.lr_scheduler |
| 921 | lr_warmup_epochs = args.lr_warmup_epochs | 920 | lr_warmup_epochs = args.lr_warmup_epochs |
| 922 | lr_cycles = args.lr_cycles | 921 | lr_cycles = args.lr_cycles |
| @@ -929,6 +928,12 @@ def main(): | |||
| 929 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 928 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
| 930 | 929 | ||
| 931 | if response.lower().strip() == "o": | 930 | if response.lower().strip() == "o": |
| 931 | if args.learning_rate is not None: | ||
| 932 | learning_rate = args.learning_rate * 2 | ||
| 933 | else: | ||
| 934 | learning_rate = args.learning_rate | ||
| 935 | |||
| 936 | if response.lower().strip() == "o": | ||
| 932 | lr_scheduler = "one_cycle" | 937 | lr_scheduler = "one_cycle" |
| 933 | lr_warmup_epochs = args.lr_warmup_epochs | 938 | lr_warmup_epochs = args.lr_warmup_epochs |
| 934 | lr_cycles = args.lr_cycles | 939 | lr_cycles = args.lr_cycles |
| @@ -945,7 +950,7 @@ def main(): | |||
| 945 | break | 950 | break |
| 946 | 951 | ||
| 947 | print("") | 952 | print("") |
| 948 | print(f"------------ TI cycle {training_iter + 1} ------------") | 953 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") |
| 949 | print("") | 954 | print("") |
| 950 | 955 | ||
| 951 | optimizer = create_optimizer( | 956 | optimizer = create_optimizer( |
