author    Volpeon <git@volpeon.ink>  2023-04-21 11:43:50 +0200
committer Volpeon <git@volpeon.ink>  2023-04-21 11:43:50 +0200
commit    7da4f0485032bb8b8acfc678546ffcea3a23a44b (patch)
tree      1e7880189df21132861114b5dbf4c614405c9855
parent    Fix PTI (diff)
Update
 train_lora.py             | 6 ++++--
 training/functional.py    | 4 ++--
 training/strategy/lora.py | 8 +++++---
 training/strategy/ti.py   | 2 +-
 4 files changed, 12 insertions(+), 8 deletions(-)
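This commit threads a new train_text_encoder_cycles option from train_lora.py into the LoRA strategy, moves offset_noise_strength out of the shared trainer setup and into the per-phase train calls (0 for the PTI phase, the CLI value for the LoRA phase), and makes cycle numbering start at 0. The matching CLI flag is not part of this diff; a minimal argparse sketch of how it could be declared (only the name args.train_text_encoder_cycles is taken from the diff, the default and help text are assumptions):

    # Hypothetical declaration of the flag read as args.train_text_encoder_cycles below;
    # the default and help string are assumptions, not part of this commit.
    parser.add_argument(
        "--train_text_encoder_cycles",
        type=int,
        default=999999,
        help="Train the text encoder only during the first N LoRA cycles",
    )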
diff --git a/train_lora.py b/train_lora.py
index 0d8b8cb..1d1485d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -873,7 +873,6 @@ def main():
         seed=args.seed,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
-        offset_noise_strength=args.offset_noise_strength,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
@@ -984,13 +983,14 @@ def main():
         lr_scheduler=pti_lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        cycle=1,
+        cycle=0,
         pti_mode=True,
         # --
         group_labels=["emb"],
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
+        offset_noise_strength=0,
         no_val=True,
     )
 
@@ -1132,11 +1132,13 @@ def main():
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         global_step_offset=training_iter * num_train_steps,
         cycle=training_iter,
+        train_text_encoder_cycles=args.train_text_encoder_cycles,
         # --
         group_labels=group_labels,
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
+        offset_noise_strength=args.offset_noise_strength,
         no_val=args.valid_set_size == 0,
     )
 
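Net effect in train_lora.py: offset noise is disabled for the PTI phase (offset_noise_strength=0) and applied with the user-supplied strength only during the LoRA cycles, while the new train_text_encoder_cycles value is forwarded to the LoRA strategy. For reference, offset noise is commonly mixed into the sampled noise as in this sketch (function and variable names are illustrative, not taken from this repository):

    import torch

    def add_offset_noise(noise: torch.Tensor, strength: float) -> torch.Tensor:
        # Offset noise adds a per-channel constant shift to the Gaussian noise,
        # which helps the model learn overall brightness; strength=0 disables it.
        if strength == 0:
            return noise
        offset = torch.randn(noise.shape[0], noise.shape[1], 1, 1, device=noise.device)
        return noise + strength * offset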
diff --git a/training/functional.py b/training/functional.py
index c6ceb20..695a24f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -456,7 +456,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    cycle: int = 1,
+    cycle: int = 0,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -537,7 +537,7 @@ def train_loop(
 
         logs = {}
 
-        with on_train(epoch):
+        with on_train(cycle):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
                 loss /= gradient_accumulation_steps
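train_loop now hands the training cycle rather than the epoch index to the strategy's on_train callback, and the cycle default becomes 0-based. A minimal sketch of the callback contract this implies (simplified; the real loop also handles gradient accumulation, sampling and checkpoints):

    from contextlib import contextmanager

    @contextmanager
    def on_train(cycle: int):
        # Strategies receive the current cycle, not the epoch, so they can
        # change behaviour based on how many training cycles have completed.
        print(f"train mode, cycle {cycle}")
        yield

    cycle = 0  # the same value is reused for every epoch of one train_loop call
    for epoch in range(2):
        with on_train(cycle):
            pass  # per-batch loss and optimizer steps go here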
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     pti_mode: bool = False,
+    train_text_encoder_cycles: int = 99999,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks(
     )
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
-        text_encoder.train()
-        tokenizer.train()
+        if cycle < train_text_encoder_cycles:
+            text_encoder.train()
+            tokenizer.train()
         yield
 
     @contextmanager
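With the cycle argument available, the LoRA strategy keeps the UNet in training mode every cycle but calls text_encoder.train() and tokenizer.train() only while cycle < train_text_encoder_cycles; the default of 99999 effectively means "always" unless the caller passes a smaller value. A small illustration of the gate (the loop and example value are illustrative, only the comparison comes from the diff):

    train_text_encoder_cycles = 2  # example value, not the default
    for cycle in range(4):
        print(cycle, cycle < train_text_encoder_cycles)
    # 0 True   -> text_encoder.train() and tokenizer.train() run
    # 1 True
    # 2 False  -> both calls are skipped from this cycle on
    # 3 False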
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f330cb7..6bc1d7d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -90,7 +90,7 @@ def textual_inversion_strategy_callbacks(
         return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         text_encoder.train()
         tokenizer.train()
         yield