-rw-r--r--   train_lora.py                    | 13
-rw-r--r--   training/functional.py           |  6
-rw-r--r--   training/strategy/dreambooth.py  |  6
-rw-r--r--   training/strategy/lora.py        |  9
-rw-r--r--   training/strategy/ti.py          |  6
5 files changed, 16 insertions(+), 24 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index daf1f6c..476efcf 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -548,15 +548,18 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")
 
+    if args.initializer_tokens is None:
+        args.initializer_tokens = []
+
+    if args.placeholder_tokens is None:
+        args.placeholder_tokens = []
+
     if isinstance(args.placeholder_tokens, str):
         args.placeholder_tokens = [args.placeholder_tokens]
 
     if isinstance(args.initializer_tokens, str):
         args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
 
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
     if len(args.placeholder_tokens) == 0:
         args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
 
@@ -884,7 +887,7 @@ def main():
     num_pti_epochs = math.ceil(
         args.num_pti_steps / len(pti_datamodule.train_dataset)
     ) * args.pti_gradient_accumulation_steps
-    pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+    pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
 
     pti_optimizer = create_optimizer(
         [
@@ -915,7 +918,7 @@ def main():
         # --
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
+        sample_frequency=math.inf,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         use_emb_decay=args.use_emb_decay,
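Note on the pti_sample_frequency change above: the interval starts out expressed in optimizer steps and is converted into an interval in PTI epochs, so the ratio has to be taken against the PTI step count, not the main training step count. A minimal sketch of that arithmetic, assuming this interpretation and using made-up values:

import math

# Hypothetical values, for illustration only.
num_pti_steps = 500        # --num_pti_steps
dataset_len = 100          # len(pti_datamodule.train_dataset)
accum_steps = 4            # --pti_gradient_accumulation_steps
sample_freq_steps = 50     # sampling interval, expressed in steps

num_pti_epochs = math.ceil(num_pti_steps / dataset_len) * accum_steps  # 20

# Fixed version: scale by the PTI phase's own step count.
pti_sample_frequency = math.ceil(num_pti_epochs * (sample_freq_steps / num_pti_steps))
print(pti_sample_frequency)  # -> 2, i.e. sample every second PTI epoch

The second hunk then passes sample_frequency=math.inf into the PTI phase, which effectively disables intermediate sampling there regardless of the computed value.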
diff --git a/training/functional.py b/training/functional.py
index c30d1c0..4d83df1 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,6 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[float, int], Any] = const()
@@ -461,7 +460,6 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")
 
-    model = callbacks.on_accum_model()
     on_log = callbacks.on_log
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
@@ -498,8 +496,6 @@ def train_loop(
         local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
         local_progress_bar.reset()
 
-        model.train()
-
         with on_train(epoch):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
@@ -560,8 +556,6 @@ def train_loop(
         on_after_epoch()
 
         if val_dataloader is not None:
-            model.eval()
-
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()
 
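For context, the training/functional.py hunks remove the on_accum_model hook entirely: train_loop no longer fetches a model and toggles its mode itself, because each strategy now switches its own modules inside its on_train/on_eval context managers. A simplified sketch of the resulting shape (argument lists trimmed, not the verbatim implementation):

from contextlib import _GeneratorContextManager, nullcontext
from dataclasses import dataclass
from typing import Any, Callable


def const(result=None):
    def fn(*args, **kwargs):
        return result
    return fn


@dataclass
class TrainingCallbacks():
    # on_accum_model is gone; mode switching now lives in on_train / on_eval.
    on_log: Callable[[], dict[str, Any]] = const({})
    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())


def train_loop(num_epochs, train_dataloader, val_dataloader, callbacks: TrainingCallbacks):
    for epoch in range(num_epochs):
        with callbacks.on_train(epoch):      # strategy puts its modules into train mode here
            for step, batch in enumerate(train_dataloader):
                ...                          # loss_step, backward, optimizer step
        if val_dataloader is not None:
            with callbacks.on_eval():        # strategy puts its modules into eval mode here
                for step, batch in enumerate(val_dataloader):
                    ...                      # validation loss_step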
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 9808027..0286673 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -84,11 +84,9 @@ def dreambooth_strategy_callbacks(
         else:
             return nullcontext()
 
-    def on_accum_model():
-        return unet
-
     @contextmanager
     def on_train(epoch: int):
+        unet.train()
         tokenizer.train()
 
         if epoch < train_text_encoder_epochs:
@@ -101,6 +99,7 @@ def dreambooth_strategy_callbacks(
 
     @contextmanager
     def on_eval():
+        unet.eval()
         tokenizer.eval()
         text_encoder.eval()
 
@@ -174,7 +173,6 @@ def dreambooth_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 6730dc9..80ffa9c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -64,17 +64,17 @@ def lora_strategy_callbacks(
         image_size=sample_image_size,
     )
 
-    def on_accum_model():
-        return unet
-
     @contextmanager
     def on_train(epoch: int):
-        tokenizer.train()
+        unet.train()
         text_encoder.train()
+        tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
+        unet.eval()
+        text_encoder.eval()
         tokenizer.eval()
         yield
 
@@ -152,7 +152,6 @@ def lora_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 55e9934..6a637c3 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks(
         else:
             return nullcontext()
 
-    def on_accum_model():
-        return text_encoder.text_model.embeddings.token_override_embedding.params
-
     @contextmanager
     def on_train(epoch: int):
+        text_encoder.text_model.embeddings.token_override_embedding.params.train()
         tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
+        text_encoder.text_model.embeddings.token_override_embedding.params.eval()
         tokenizer.eval()
 
         with ema_context():
@@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks(
         torch.cuda.empty_cache()
 
     return TrainingCallbacks(
-        on_accum_model=on_accum_model,
        on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
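Seen side by side, the three strategies now follow the same convention: whatever each one used to expose through on_accum_model is put into train/eval mode directly inside its own on_train/on_eval context manager (the unet for dreambooth and lora, the token-override embedding parameters for ti). A generic sketch of that pattern, assuming the TrainingCallbacks dataclass from training/functional.py (the function name and argument list here are placeholders, not code from the repository):

from contextlib import contextmanager

from training.functional import TrainingCallbacks


def strategy_callbacks(unet, text_encoder, tokenizer):
    @contextmanager
    def on_train(epoch: int):
        # Each strategy switches exactly the modules it trains.
        unet.train()
        text_encoder.train()
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        unet.eval()
        text_encoder.eval()
        tokenizer.eval()
        yield

    return TrainingCallbacks(on_train=on_train, on_eval=on_eval)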