summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-08 07:58:14 +0200
committerVolpeon <git@volpeon.ink>2023-04-08 07:58:14 +0200
commit5e84594c56237cd2c7d7f80858e5da8c11aa3f89 (patch)
treeb1483a52fb853aecb7b73635cded3cce61edf125
parentFix (diff)
downloadtextual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.gz
textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.tar.bz2
textual-inversion-diff-5e84594c56237cd2c7d7f80858e5da8c11aa3f89.zip
Update
-rw-r--r--train_dreambooth.py5
-rw-r--r--train_lora.py75
-rw-r--r--train_ti.py12
-rw-r--r--training/functional.py57
-rw-r--r--training/strategy/dreambooth.py2
-rw-r--r--training/strategy/lora.py12
-rw-r--r--training/strategy/ti.py2
7 files changed, 80 insertions, 85 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 48921d4..f4d4cbb 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -18,7 +18,6 @@ import transformers
18from util.files import load_config, load_embeddings_from_dir 18from util.files import load_config, load_embeddings_from_dir
19from data.csv import VlpnDataModule, keyword_filter 19from data.csv import VlpnDataModule, keyword_filter
20from training.functional import train, get_models 20from training.functional import train, get_models
21from training.lr import plot_metrics
22from training.strategy.dreambooth import dreambooth_strategy 21from training.strategy.dreambooth import dreambooth_strategy
23from training.optimization import get_scheduler 22from training.optimization import get_scheduler
24from training.util import save_args 23from training.util import save_args
@@ -692,7 +691,7 @@ def main():
692 mid_point=args.lr_mid_point, 691 mid_point=args.lr_mid_point,
693 ) 692 )
694 693
695 metrics = trainer( 694 trainer(
696 strategy=dreambooth_strategy, 695 strategy=dreambooth_strategy,
697 project="dreambooth", 696 project="dreambooth",
698 train_dataloader=datamodule.train_dataloader, 697 train_dataloader=datamodule.train_dataloader,
@@ -721,8 +720,6 @@ def main():
721 sample_image_size=args.sample_image_size, 720 sample_image_size=args.sample_image_size,
722 ) 721 )
723 722
724 plot_metrics(metrics, output_dir / "lr.png")
725
726 723
727if __name__ == "__main__": 724if __name__ == "__main__":
728 main() 725 main()
diff --git a/train_lora.py b/train_lora.py
index 9f17495..1626be6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -19,7 +19,6 @@ import transformers
19from util.files import load_config, load_embeddings_from_dir 19from util.files import load_config, load_embeddings_from_dir
20from data.csv import VlpnDataModule, keyword_filter 20from data.csv import VlpnDataModule, keyword_filter
21from training.functional import train, add_placeholder_tokens, get_models 21from training.functional import train, add_placeholder_tokens, get_models
22from training.lr import plot_metrics
23from training.strategy.lora import lora_strategy 22from training.strategy.lora import lora_strategy
24from training.optimization import get_scheduler 23from training.optimization import get_scheduler
25from training.util import save_args 24from training.util import save_args
@@ -568,6 +567,9 @@ def parse_args():
568 if len(args.placeholder_tokens) == 0: 567 if len(args.placeholder_tokens) == 0:
569 args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] 568 args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
570 569
570 if len(args.initializer_tokens) == 0:
571 args.initializer_tokens = args.placeholder_tokens.copy()
572
571 if len(args.placeholder_tokens) != len(args.initializer_tokens): 573 if len(args.placeholder_tokens) != len(args.initializer_tokens):
572 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") 574 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
573 575
@@ -918,7 +920,7 @@ def main():
918 train_epochs=num_pti_epochs, 920 train_epochs=num_pti_epochs,
919 ) 921 )
920 922
921 metrics = trainer( 923 trainer(
922 strategy=lora_strategy, 924 strategy=lora_strategy,
923 pti_mode=True, 925 pti_mode=True,
924 project="pti", 926 project="pti",
@@ -929,12 +931,13 @@ def main():
929 num_train_epochs=num_pti_epochs, 931 num_train_epochs=num_pti_epochs,
930 gradient_accumulation_steps=args.pti_gradient_accumulation_steps, 932 gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
931 # -- 933 # --
934 group_labels=["emb"],
932 sample_output_dir=pti_sample_output_dir, 935 sample_output_dir=pti_sample_output_dir,
933 checkpoint_output_dir=pti_checkpoint_output_dir, 936 checkpoint_output_dir=pti_checkpoint_output_dir,
934 sample_frequency=pti_sample_frequency, 937 sample_frequency=pti_sample_frequency,
935 ) 938 )
936 939
937 plot_metrics(metrics, pti_output_dir / "lr.png") 940 # embeddings.persist()
938 941
939 # LORA 942 # LORA
940 # -------------------------------------------------------------------------------- 943 # --------------------------------------------------------------------------------
@@ -957,34 +960,39 @@ def main():
957 ) * args.gradient_accumulation_steps 960 ) * args.gradient_accumulation_steps
958 lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) 961 lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps))
959 962
960 lora_optimizer = create_optimizer( 963 params_to_optimize = []
961 [ 964 group_labels = []
962 { 965 if len(args.placeholder_tokens) != 0:
963 "params": ( 966 params_to_optimize.append({
964 param 967 "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
965 for param in unet.parameters() 968 "lr": args.learning_rate_text,
966 if param.requires_grad 969 "weight_decay": 0,
967 ), 970 })
968 "lr": args.learning_rate_unet, 971 group_labels.append("emb")
969 }, 972 params_to_optimize += [
970 { 973 {
971 "params": ( 974 "params": (
972 param 975 param
973 for param in itertools.chain( 976 for param in unet.parameters()
974 text_encoder.text_model.encoder.parameters(), 977 if param.requires_grad
975 text_encoder.text_model.final_layer_norm.parameters(), 978 ),
976 ) 979 "lr": args.learning_rate_unet,
977 if param.requires_grad 980 },
978 ), 981 {
979 "lr": args.learning_rate_text, 982 "params": (
980 }, 983 param
981 { 984 for param in itertools.chain(
982 "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(), 985 text_encoder.text_model.encoder.parameters(),
983 "lr": args.learning_rate_text, 986 text_encoder.text_model.final_layer_norm.parameters(),
984 "weight_decay": 0, 987 )
985 }, 988 if param.requires_grad
986 ] 989 ),
987 ) 990 "lr": args.learning_rate_text,
991 },
992 ]
993 group_labels += ["unet", "text"]
994
995 lora_optimizer = create_optimizer(params_to_optimize)
988 996
989 lora_lr_scheduler = create_lr_scheduler( 997 lora_lr_scheduler = create_lr_scheduler(
990 gradient_accumulation_steps=args.gradient_accumulation_steps, 998 gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -993,7 +1001,7 @@ def main():
993 train_epochs=num_train_epochs, 1001 train_epochs=num_train_epochs,
994 ) 1002 )
995 1003
996 metrics = trainer( 1004 trainer(
997 strategy=lora_strategy, 1005 strategy=lora_strategy,
998 project="lora", 1006 project="lora",
999 train_dataloader=lora_datamodule.train_dataloader, 1007 train_dataloader=lora_datamodule.train_dataloader,
@@ -1003,13 +1011,12 @@ def main():
1003 num_train_epochs=num_train_epochs, 1011 num_train_epochs=num_train_epochs,
1004 gradient_accumulation_steps=args.gradient_accumulation_steps, 1012 gradient_accumulation_steps=args.gradient_accumulation_steps,
1005 # -- 1013 # --
1014 group_labels=group_labels,
1006 sample_output_dir=lora_sample_output_dir, 1015 sample_output_dir=lora_sample_output_dir,
1007 checkpoint_output_dir=lora_checkpoint_output_dir, 1016 checkpoint_output_dir=lora_checkpoint_output_dir,
1008 sample_frequency=lora_sample_frequency, 1017 sample_frequency=lora_sample_frequency,
1009 ) 1018 )
1010 1019
1011 plot_metrics(metrics, lora_output_dir / "lr.png")
1012
1013 1020
1014if __name__ == "__main__": 1021if __name__ == "__main__":
1015 main() 1022 main()
diff --git a/train_ti.py b/train_ti.py
index c1c0eed..48858cc 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,6 @@ import transformers
17from util.files import load_config, load_embeddings_from_dir 17from util.files import load_config, load_embeddings_from_dir
18from data.csv import VlpnDataModule, keyword_filter 18from data.csv import VlpnDataModule, keyword_filter
19from training.functional import train, add_placeholder_tokens, get_models 19from training.functional import train, add_placeholder_tokens, get_models
20from training.lr import plot_metrics
21from training.strategy.ti import textual_inversion_strategy 20from training.strategy.ti import textual_inversion_strategy
22from training.optimization import get_scheduler 21from training.optimization import get_scheduler
23from training.util import save_args 22from training.util import save_args
@@ -511,12 +510,12 @@ def parse_args():
511 if isinstance(args.initializer_tokens, str): 510 if isinstance(args.initializer_tokens, str):
512 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) 511 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
513 512
514 if len(args.initializer_tokens) == 0:
515 raise ValueError("You must specify --initializer_tokens")
516
517 if len(args.placeholder_tokens) == 0: 513 if len(args.placeholder_tokens) == 0:
518 args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] 514 args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
519 515
516 if len(args.initializer_tokens) == 0:
517 args.initializer_tokens = args.placeholder_tokens.copy()
518
520 if len(args.placeholder_tokens) != len(args.initializer_tokens): 519 if len(args.placeholder_tokens) != len(args.initializer_tokens):
521 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") 520 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
522 521
@@ -856,7 +855,7 @@ def main():
856 mid_point=args.lr_mid_point, 855 mid_point=args.lr_mid_point,
857 ) 856 )
858 857
859 metrics = trainer( 858 trainer(
860 project="textual_inversion", 859 project="textual_inversion",
861 train_dataloader=datamodule.train_dataloader, 860 train_dataloader=datamodule.train_dataloader,
862 val_dataloader=datamodule.val_dataloader, 861 val_dataloader=datamodule.val_dataloader,
@@ -864,14 +863,13 @@ def main():
864 lr_scheduler=lr_scheduler, 863 lr_scheduler=lr_scheduler,
865 num_train_epochs=num_train_epochs, 864 num_train_epochs=num_train_epochs,
866 # -- 865 # --
866 group_labels=["emb"],
867 sample_output_dir=sample_output_dir, 867 sample_output_dir=sample_output_dir,
868 sample_frequency=sample_frequency, 868 sample_frequency=sample_frequency,
869 placeholder_tokens=placeholder_tokens, 869 placeholder_tokens=placeholder_tokens,
870 placeholder_token_ids=placeholder_token_ids, 870 placeholder_token_ids=placeholder_token_ids,
871 ) 871 )
872 872
873 plot_metrics(metrics, metrics_output_file)
874
875 if not args.sequential: 873 if not args.sequential:
876 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) 874 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
877 else: 875 else:
diff --git a/training/functional.py b/training/functional.py
index 4d83df1..71b2fe9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -36,8 +36,8 @@ def const(result=None):
36class TrainingCallbacks(): 36class TrainingCallbacks():
37 on_log: Callable[[], dict[str, Any]] = const({}) 37 on_log: Callable[[], dict[str, Any]] = const({})
38 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 38 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
39 on_before_optimize: Callable[[float, int], Any] = const() 39 on_before_optimize: Callable[[int], Any] = const()
40 on_after_optimize: Callable[[Any, float], None] = const() 40 on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
41 on_after_epoch: Callable[[], None] = const() 41 on_after_epoch: Callable[[], None] = const()
42 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) 42 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
43 on_sample: Callable[[int], None] = const() 43 on_sample: Callable[[int], None] = const()
@@ -422,6 +422,7 @@ def train_loop(
422 global_step_offset: int = 0, 422 global_step_offset: int = 0,
423 num_epochs: int = 100, 423 num_epochs: int = 100,
424 gradient_accumulation_steps: int = 1, 424 gradient_accumulation_steps: int = 1,
425 group_labels: list[str] = [],
425 callbacks: TrainingCallbacks = TrainingCallbacks(), 426 callbacks: TrainingCallbacks = TrainingCallbacks(),
426): 427):
427 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 428 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -442,10 +443,6 @@ def train_loop(
442 best_acc = 0.0 443 best_acc = 0.0
443 best_acc_val = 0.0 444 best_acc_val = 0.0
444 445
445 lrs = []
446 losses = []
447 accs = []
448
449 local_progress_bar = tqdm( 446 local_progress_bar = tqdm(
450 range(num_training_steps_per_epoch + num_val_steps_per_epoch), 447 range(num_training_steps_per_epoch + num_val_steps_per_epoch),
451 disable=not accelerator.is_local_main_process, 448 disable=not accelerator.is_local_main_process,
@@ -496,6 +493,8 @@ def train_loop(
496 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 493 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
497 local_progress_bar.reset() 494 local_progress_bar.reset()
498 495
496 logs = {}
497
499 with on_train(epoch): 498 with on_train(epoch):
500 for step, batch in enumerate(train_dataloader): 499 for step, batch in enumerate(train_dataloader):
501 loss, acc, bsz = loss_step(step, batch, cache) 500 loss, acc, bsz = loss_step(step, batch, cache)
@@ -506,31 +505,36 @@ def train_loop(
506 avg_loss.update(loss.detach_(), bsz) 505 avg_loss.update(loss.detach_(), bsz)
507 avg_acc.update(acc.detach_(), bsz) 506 avg_acc.update(acc.detach_(), bsz)
508 507
509 lr = lr_scheduler.get_last_lr()[0]
510 if torch.is_tensor(lr):
511 lr = lr.item()
512
513 logs = { 508 logs = {
514 "train/loss": avg_loss.avg.item(), 509 "train/loss": avg_loss.avg.item(),
515 "train/acc": avg_acc.avg.item(), 510 "train/acc": avg_acc.avg.item(),
516 "train/cur_loss": loss.item(), 511 "train/cur_loss": loss.item(),
517 "train/cur_acc": acc.item(), 512 "train/cur_acc": acc.item(),
518 "lr": lr,
519 } 513 }
520 if isDadaptation: 514
521 logs["lr/d*lr"] = lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] 515 lrs: dict[str, float] = {}
516 for i, lr in enumerate(lr_scheduler.get_last_lr()):
517 if torch.is_tensor(lr):
518 lr = lr.item()
519 label = group_labels[i] if i < len(group_labels) else f"{i}"
520 logs[f"lr/{label}"] = lr
521 if isDadaptation:
522 lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
523 logs[f"d*lr/{label}"] = lr
524 lrs[label] = lr
525
522 logs.update(on_log()) 526 logs.update(on_log())
523 527
524 local_progress_bar.set_postfix(**logs) 528 local_progress_bar.set_postfix(**logs)
525 529
526 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): 530 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
527 before_optimize_result = on_before_optimize(lr, epoch) 531 before_optimize_result = on_before_optimize(epoch)
528 532
529 optimizer.step() 533 optimizer.step()
530 lr_scheduler.step() 534 lr_scheduler.step()
531 optimizer.zero_grad(set_to_none=True) 535 optimizer.zero_grad(set_to_none=True)
532 536
533 on_after_optimize(before_optimize_result, lr) 537 on_after_optimize(before_optimize_result, lrs)
534 538
535 local_progress_bar.update(1) 539 local_progress_bar.update(1)
536 global_progress_bar.update(1) 540 global_progress_bar.update(1)
@@ -544,15 +548,6 @@ def train_loop(
544 548
545 accelerator.wait_for_everyone() 549 accelerator.wait_for_everyone()
546 550
547 if isDadaptation:
548 lr = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
549 else:
550 lr = lr_scheduler.get_last_lr()[0]
551 if torch.is_tensor(lr):
552 lr = lr.item()
553
554 lrs.append(lr)
555
556 on_after_epoch() 551 on_after_epoch()
557 552
558 if val_dataloader is not None: 553 if val_dataloader is not None:
@@ -597,9 +592,6 @@ def train_loop(
597 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") 592 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
598 on_checkpoint(global_step + global_step_offset, "milestone") 593 on_checkpoint(global_step + global_step_offset, "milestone")
599 best_acc_val = avg_acc_val.avg.item() 594 best_acc_val = avg_acc_val.avg.item()
600
601 losses.append(avg_loss_val.avg.item())
602 accs.append(avg_acc_val.avg.item())
603 else: 595 else:
604 if accelerator.is_main_process: 596 if accelerator.is_main_process:
605 if avg_acc.avg.item() > best_acc and milestone_checkpoints: 597 if avg_acc.avg.item() > best_acc and milestone_checkpoints:
@@ -611,9 +603,6 @@ def train_loop(
611 on_checkpoint(global_step + global_step_offset, "milestone") 603 on_checkpoint(global_step + global_step_offset, "milestone")
612 best_acc = avg_acc.avg.item() 604 best_acc = avg_acc.avg.item()
613 605
614 losses.append(avg_loss.avg.item())
615 accs.append(avg_acc.avg.item())
616
617 # Create the pipeline using using the trained modules and save it. 606 # Create the pipeline using using the trained modules and save it.
618 if accelerator.is_main_process: 607 if accelerator.is_main_process:
619 print("Finished!") 608 print("Finished!")
@@ -626,8 +615,6 @@ def train_loop(
626 on_checkpoint(global_step + global_step_offset, "end") 615 on_checkpoint(global_step + global_step_offset, "end")
627 raise KeyboardInterrupt 616 raise KeyboardInterrupt
628 617
629 return lrs, losses, accs
630
631 618
632def train( 619def train(
633 accelerator: Accelerator, 620 accelerator: Accelerator,
@@ -646,6 +633,7 @@ def train(
646 no_val: bool = False, 633 no_val: bool = False,
647 num_train_epochs: int = 100, 634 num_train_epochs: int = 100,
648 gradient_accumulation_steps: int = 1, 635 gradient_accumulation_steps: int = 1,
636 group_labels: list[str] = [],
649 sample_frequency: int = 20, 637 sample_frequency: int = 20,
650 checkpoint_frequency: int = 50, 638 checkpoint_frequency: int = 50,
651 milestone_checkpoints: bool = True, 639 milestone_checkpoints: bool = True,
@@ -692,7 +680,7 @@ def train(
692 if accelerator.is_main_process: 680 if accelerator.is_main_process:
693 accelerator.init_trackers(project) 681 accelerator.init_trackers(project)
694 682
695 metrics = train_loop( 683 train_loop(
696 accelerator=accelerator, 684 accelerator=accelerator,
697 optimizer=optimizer, 685 optimizer=optimizer,
698 lr_scheduler=lr_scheduler, 686 lr_scheduler=lr_scheduler,
@@ -705,10 +693,9 @@ def train(
705 global_step_offset=global_step_offset, 693 global_step_offset=global_step_offset,
706 num_epochs=num_train_epochs, 694 num_epochs=num_train_epochs,
707 gradient_accumulation_steps=gradient_accumulation_steps, 695 gradient_accumulation_steps=gradient_accumulation_steps,
696 group_labels=group_labels,
708 callbacks=callbacks, 697 callbacks=callbacks,
709 ) 698 )
710 699
711 accelerator.end_training() 700 accelerator.end_training()
712 accelerator.free_memory() 701 accelerator.free_memory()
713
714 return metrics
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0286673..695174a 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -106,7 +106,7 @@ def dreambooth_strategy_callbacks(
106 with ema_context(): 106 with ema_context():
107 yield 107 yield
108 108
109 def on_before_optimize(lr: float, epoch: int): 109 def on_before_optimize(epoch: int):
110 params_to_clip = [unet.parameters()] 110 params_to_clip = [unet.parameters()]
111 if epoch < train_text_encoder_epochs: 111 if epoch < train_text_encoder_epochs:
112 params_to_clip.append(text_encoder.parameters()) 112 params_to_clip.append(text_encoder.parameters())
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 912ff26..89269c0 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -79,10 +79,14 @@ def lora_strategy_callbacks(
79 tokenizer.eval() 79 tokenizer.eval()
80 yield 80 yield
81 81
82 def on_before_optimize(lr: float, epoch: int): 82 def on_before_optimize(epoch: int):
83 if not pti_mode: 83 if not pti_mode:
84 accelerator.clip_grad_norm_( 84 accelerator.clip_grad_norm_(
85 itertools.chain(unet.parameters(), text_encoder.parameters()), 85 itertools.chain(
86 unet.parameters(),
87 text_encoder.text_model.encoder.parameters(),
88 text_encoder.text_model.final_layer_norm.parameters(),
89 ),
86 max_grad_norm 90 max_grad_norm
87 ) 91 )
88 92
@@ -95,7 +99,9 @@ def lora_strategy_callbacks(
95 return torch.stack(params) if len(params) != 0 else None 99 return torch.stack(params) if len(params) != 0 else None
96 100
97 @torch.no_grad() 101 @torch.no_grad()
98 def on_after_optimize(w, lr: float): 102 def on_after_optimize(w, lrs: dict[str, float]):
103 lr = lrs["emb"] or lrs["0"]
104
99 if use_emb_decay and w is not None: 105 if use_emb_decay and w is not None:
100 lambda_ = emb_decay * lr 106 lambda_ = emb_decay * lr
101 107
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6a637c3..d735dac 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
104 yield 104 yield
105 105
106 @torch.no_grad() 106 @torch.no_grad()
107 def on_before_optimize(lr: float, epoch: int): 107 def on_before_optimize(epoch: int):
108 if use_emb_decay: 108 if use_emb_decay:
109 params = [ 109 params = [
110 p 110 p