author    Volpeon <git@volpeon.ink>  2023-04-16 19:03:25 +0200
committer Volpeon <git@volpeon.ink>  2023-04-16 19:03:25 +0200
commit    71f4a40bb48be4f2759ba2d83faff39691cb2955 (patch)
tree      29c704ca549a4c4323403b6cbb0e62f54040ae22
parent    Added option to use constant LR on cycles > 1 (diff)
Improved automation caps
-rw-r--r--  train_lora.py                    | 53
-rw-r--r--  train_ti.py                      | 53
-rw-r--r--  training/functional.py           | 17
-rw-r--r--  training/strategy/dreambooth.py  |  4
-rw-r--r--  training/strategy/lora.py        |  4
-rw-r--r--  training/strategy/ti.py          | 23
6 files changed, 100 insertions(+), 54 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 4d4c16a..ba5aee1 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -84,9 +84,9 @@ def parse_args():
     )
     parser.add_argument(
         "--auto_cycles",
-        type=int,
-        default=1,
-        help="How many cycles to run automatically."
+        type=str,
+        default="o",
+        help="Cycles to run automatically."
     )
     parser.add_argument(
         "--cycle_decay",
@@ -95,11 +95,6 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
-        "--cycle_constant",
-        action="store_true",
-        help="Use constant LR on cycles > 1."
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -920,7 +915,6 @@ def main():
         annealing_func=args.lr_annealing_func,
         warmup_exp=args.lr_warmup_exp,
         annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
         end_lr=1e2,
         mid_point=args.lr_mid_point,
     )
@@ -964,20 +958,38 @@ def main():
 
     lora_sample_output_dir = output_dir / lora_project / "samples"
 
+    auto_cycles = list(args.auto_cycles)
+    lr_scheduler = args.lr_scheduler
+    lr_warmup_epochs = args.lr_warmup_epochs
+    lr_cycles = args.lr_cycles
+
     while True:
-        if training_iter >= args.auto_cycles:
-            response = input("Run another cycle? [y/n] ")
-            if response.lower().strip() == "n":
-                break
+        if len(auto_cycles) != 0:
+            response = auto_cycles.pop(0)
+        else:
+            response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+
+        if response.lower().strip() == "o":
+            lr_scheduler = "one_cycle"
+            lr_warmup_epochs = args.lr_warmup_epochs
+            lr_cycles = args.lr_cycles
+        if response.lower().strip() == "w":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = num_train_epochs
+        if response.lower().strip() == "c":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = 0
+        if response.lower().strip() == "d":
+            lr_scheduler = "cosine"
+            lr_warmup_epochs = 0
+            lr_cycles = 1
+        elif response.lower().strip() == "s":
+            break
 
         print("")
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
-        if args.cycle_constant and training_iter == 1:
-            args.lr_scheduler = "constant"
-            args.lr_warmup_epochs = 0
-
         params_to_optimize = []
 
         if len(args.placeholder_tokens) != 0:
@@ -1012,12 +1024,13 @@ def main():
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
-            args.lr_scheduler,
+            lr_scheduler,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             optimizer=lora_optimizer,
             num_training_steps_per_epoch=len(lora_datamodule.train_dataloader),
             train_epochs=num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
+            cycles=lr_cycles,
+            warmup_epochs=lr_warmup_epochs,
         )
 
         lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
@@ -1031,7 +1044,7 @@ def main():
             num_train_epochs=num_train_epochs,
             gradient_accumulation_steps=args.gradient_accumulation_steps,
             global_step_offset=training_iter * num_train_steps,
-            initial_samples=training_iter == 0,
+            cycle=training_iter,
             # --
             group_labels=group_labels,
             sample_output_dir=lora_sample_output_dir,
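
Note: after this commit, --auto_cycles is no longer a cycle count but a string of per-cycle action codes consumed left to right ("o" one_cycle, "w" warmup, "c" constant, "d" decay, "s" stop); once the string is exhausted, the loop falls back to the interactive prompt. For example, --auto_cycles ods runs a one_cycle pass, then a cosine decay pass, then stops without prompting. train_ti.py below receives the identical change. A minimal sketch of the queue behaviour (next_action and queue are illustrative names, not part of the commit):

    def next_action(queue: list) -> str:
        # Pop the next scripted action; ask interactively once the
        # scripted actions run out.
        if len(queue) != 0:
            return queue.pop(0)
        return input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")

    queue = list("ods")  # as passed via --auto_cycles ods
    while True:
        action = next_action(queue).lower().strip()
        if action == "s":
            break
        print(f"cycle action: {action}")
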
diff --git a/train_ti.py b/train_ti.py
index c452269..880320f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -68,9 +68,9 @@ def parse_args():
     )
     parser.add_argument(
         "--auto_cycles",
-        type=int,
-        default=1,
-        help="How many cycles to run automatically."
+        type=str,
+        default="o",
+        help="Cycles to run automatically."
     )
     parser.add_argument(
         "--cycle_decay",
@@ -79,11 +79,6 @@ def parse_args():
         help="Learning rate decay per cycle."
     )
     parser.add_argument(
-        "--cycle_constant",
-        action="store_true",
-        help="Use constant LR on cycles > 1."
-    )
-    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -921,27 +916,45 @@ def main():
 
     sample_output_dir = output_dir / project / "samples"
 
+    auto_cycles = list(args.auto_cycles)
+    lr_scheduler = args.lr_scheduler
+    lr_warmup_epochs = args.lr_warmup_epochs
+    lr_cycles = args.lr_cycles
+
     while True:
-        if training_iter >= args.auto_cycles:
-            response = input("Run another cycle? [y/n] ")
-            if response.lower().strip() == "n":
-                break
+        if len(auto_cycles) != 0:
+            response = auto_cycles.pop(0)
+        else:
+            response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+
+        if response.lower().strip() == "o":
+            lr_scheduler = "one_cycle"
+            lr_warmup_epochs = args.lr_warmup_epochs
+            lr_cycles = args.lr_cycles
+        if response.lower().strip() == "w":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = num_train_epochs
+        if response.lower().strip() == "c":
+            lr_scheduler = "constant"
+            lr_warmup_epochs = 0
+        if response.lower().strip() == "d":
+            lr_scheduler = "cosine"
+            lr_warmup_epochs = 0
+            lr_cycles = 1
+        elif response.lower().strip() == "s":
+            break
 
         print("")
         print(f"------------ TI cycle {training_iter + 1} ------------")
         print("")
 
-        if args.cycle_constant and training_iter == 1:
-            args.lr_scheduler = "constant"
-            args.lr_warmup_epochs = 0
-
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_embedding.parameters(),
             lr=learning_rate,
         )
 
         lr_scheduler = get_scheduler(
-            args.lr_scheduler,
+            lr_scheduler,
             optimizer=optimizer,
             num_training_steps_per_epoch=len(datamodule.train_dataloader),
             gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -950,10 +963,10 @@ def main():
             annealing_func=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
-            cycles=args.lr_cycles,
+            cycles=lr_cycles,
             end_lr=1e3,
             train_epochs=num_train_epochs,
-            warmup_epochs=args.lr_warmup_epochs,
+            warmup_epochs=lr_warmup_epochs,
             mid_point=args.lr_mid_point,
         )
 
@@ -966,7 +979,7 @@ def main():
             lr_scheduler=lr_scheduler,
             num_train_epochs=num_train_epochs,
             global_step_offset=training_iter * num_train_steps,
-            initial_samples=training_iter == 0,
+            cycle=training_iter,
             # --
             group_labels=["emb"],
             checkpoint_output_dir=checkpoint_output_dir,
diff --git a/training/functional.py b/training/functional.py
index 2da0f69..ebc40de 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -42,7 +42,7 @@ class TrainingCallbacks():
     on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
     on_after_epoch: Callable[[], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
-    on_sample: Callable[[int], None] = const()
+    on_sample: Callable[[int, int], None] = const()
     on_checkpoint: Callable[[int, str], None] = const()
 
 
@@ -96,6 +96,7 @@ def save_samples(
     output_dir: Path,
     seed: int,
     step: int,
+    cycle: int = 1,
     batch_size: int = 1,
     num_batches: int = 1,
     num_steps: int = 20,
@@ -125,7 +126,7 @@ def save_samples(
 
     for pool, data, gen in datasets:
         all_samples = []
-        file_path = output_dir / pool / f"step_{step}.jpg"
+        file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
@@ -455,7 +456,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    initial_samples: bool = True,
+    cycle: int = 1,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -518,12 +519,12 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
+                if epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     with on_eval():
-                        on_sample(global_step)
+                        on_sample(cycle, global_step)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
@@ -648,7 +649,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             with on_eval():
-                on_sample(global_step)
+                on_sample(cycle, global_step)
             on_checkpoint(global_step, "end")
 
     except KeyboardInterrupt:
@@ -680,7 +681,7 @@ def train(
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    initial_samples: bool = True,
+    cycle: int = 1,
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
@@ -731,7 +732,7 @@ def train(
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
         milestone_checkpoints=milestone_checkpoints,
-        initial_samples=initial_samples,
+        cycle=cycle,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
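
Note: on_sample callbacks now take the cycle index alongside the global step, and save_samples() embeds the cycle in the file name, so samples from successive cycles no longer overwrite one another. A small sketch of the resulting layout (sample_path is a hypothetical helper mirroring the new naming scheme):

    from pathlib import Path

    def sample_path(output_dir: Path, pool: str, cycle: int, step: int) -> Path:
        # One file per (cycle, step) pair instead of per step.
        return output_dir / pool / f"step_{cycle}_{step}.jpg"

    print(sample_path(Path("samples"), "val", 2, 500))  # samples/val/step_2_500.jpg
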
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 4ae28b7..e6fcc89 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -148,7 +148,7 @@ def dreambooth_strategy_callbacks(
         torch.cuda.empty_cache()
 
     @torch.no_grad()
-    def on_sample(step):
+    def on_sample(cycle, step):
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
@@ -158,7 +158,7 @@ def dreambooth_strategy_callbacks(
         unet_.to(dtype=weight_dtype)
         text_encoder_.to(dtype=weight_dtype)
 
-        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+        save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_)
 
         unet_.to(dtype=orig_unet_dtype)
         text_encoder_.to(dtype=orig_text_encoder_dtype)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 48236fb..5c3012e 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -146,11 +146,11 @@ def lora_strategy_callbacks(
         torch.cuda.empty_cache()
 
     @torch.no_grad()
-    def on_sample(step):
+    def on_sample(cycle, step):
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
-        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+        save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_)
 
         del unet_, text_encoder_
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f0b84b5..6bbff64 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,10 +104,28 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
+    def on_before_optimize(epoch: int):
+        if use_emb_decay:
+            params = [
+                p
+                for p in text_encoder.text_model.embeddings.token_embedding.parameters()
+                if p.grad is not None
+            ]
+            return torch.stack(params) if len(params) != 0 else None
+
+    @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
             ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
 
+        if use_emb_decay and w is not None:
+            lr = lrs["emb"] or lrs["0"]
+            lambda_ = emb_decay * lr
+
+            if lambda_ != 0:
+                norm = w[:, :].norm(dim=-1, keepdim=True)
+                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+
     def on_log():
         if ema_embeddings is not None:
             return {"ema_decay": ema_embeddings.decay}
@@ -125,7 +143,7 @@ def textual_inversion_strategy_callbacks(
     )
 
     @torch.no_grad()
-    def on_sample(step):
+    def on_sample(cycle, step):
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
 
@@ -135,7 +153,7 @@ def textual_inversion_strategy_callbacks(
         unet_.to(dtype=weight_dtype)
         text_encoder_.to(dtype=weight_dtype)
 
-        save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+        save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_)
 
         unet_.to(dtype=orig_unet_dtype)
         text_encoder_.to(dtype=orig_text_encoder_dtype)
@@ -148,6 +166,7 @@ def textual_inversion_strategy_callbacks(
     return TrainingCallbacks(
         on_train=on_train,
         on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
         on_after_optimize=on_after_optimize,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
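
Note: the textual inversion strategy now applies embedding weight decay itself: on_before_optimize collects the embedding tensors that received gradients, and on_after_optimize pulls each vector's norm toward emb_decay_target, scaled by the current learning rate, so the decay fades together with the LR schedule. The update is w <- w + emb_decay * lr * (emb_decay_target - ||w||) * w / ||w||. A standalone sketch with demo values (emb_decay, emb_decay_target, lr and the random tensor stand in for the real variables above):

    import torch

    emb_decay, emb_decay_target, lr = 1e-2, 0.4, 1e-3
    w = torch.randn(4, 768)  # demo embedding rows

    lambda_ = emb_decay * lr
    if lambda_ != 0:
        norm = w.norm(dim=-1, keepdim=True)
        # Nudge each row's norm toward emb_decay_target along its own direction.
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))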