-rw-r--r--   infer.py                    22
-rw-r--r--   train_lora.py               13
-rw-r--r--   train_ti.py                 13
-rw-r--r--   training/functional.py      58
-rw-r--r--   training/strategy/lora.py    8
-rw-r--r--   training/util.py            22
-rw-r--r--   util/ti.py                  24
7 files changed, 119 insertions, 41 deletions
diff --git a/infer.py b/infer.py
index 4648c0a..7346de9 100644
--- a/infer.py
+++ b/infer.py
@@ -35,6 +35,7 @@ from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from util.files import load_config, load_embeddings_from_dir
+from util.ti import load_embeddings
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -229,7 +230,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(pipeline, embeddings_dir):
+def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
@@ -258,6 +259,9 @@ def load_lora(pipeline, path):
     text_encoder_lora_ds = {
         k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
     }
+    ti_lora_ds = {
+        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
+    }
 
     unet_config = LoraConfig(**lora_config["peft_config"])
     pipeline.unet = LoraModel(unet_config, pipeline.unet)
@@ -268,6 +272,18 @@ def load_lora(pipeline, path):
     pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
     set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
 
+    tokens = [k for k, _ in ti_lora_ds]
+    token_embeddings = [v for _, v in ti_lora_ds]
+
+    added_tokens, added_ids = load_embeddings(
+        tokenizer=pipeline.tokenizer,
+        embeddings=pipeline.text_encoder.text_model.embeddings,
+        tokens=tokens,
+        token_embeddings=token_embeddings,
+    )
+    pipeline.text_encoder.text_model.embeddings.persist()
+    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+
     return
 
 
@@ -435,7 +451,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         if elements[0] == 'reload_embeddings':
-            load_embeddings(self.pipeline, self.ti_embeddings_dir)
+            load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
             return
 
         try:
@@ -475,7 +491,7 @@ def main():
 
     pipeline = create_pipeline(args.model, dtype)
 
-    load_embeddings(pipeline, args.ti_embeddings_dir)
+    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
     load_lora(pipeline, args.lora_embedding)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
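Note: a minimal sketch of how the new ti_ entries travel through load_lora() above. The checkpoint key names, the <concept> token, and the tensor shapes are illustrative, not taken from a real checkpoint:

import torch

# Illustrative stand-in for the loaded LoRA checkpoint state dict.
lora_checkpoint_sd = {
    "text_encoder_text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 768),
    "ti_<concept>": torch.zeros(2, 768),  # one (num_vectors, dim) tensor per placeholder token
}

# Same filtering as in load_lora(): strip the ti_ prefix to recover the token names.
ti_lora_ds = {
    k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
}
tokens = [k for k, _ in ti_lora_ds.items()]
token_embeddings = [v for _, v in ti_lora_ds.items()]

print(tokens)                               # ['<concept>']
print([e.shape for e in token_embeddings])  # [torch.Size([2, 768])]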
diff --git a/train_lora.py b/train_lora.py
index c197206..9cf17c7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
-from training.util import save_args
+from training.util import AverageMeter, save_args
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
@@ -1035,6 +1035,11 @@ def main():
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
 
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -1122,7 +1127,7 @@ def main():
             warmup_epochs=lr_warmup_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}"
 
         trainer(
             strategy=lora_strategy,
@@ -1142,6 +1147,10 @@ def main():
             sample_frequency=lora_sample_frequency,
             offset_noise_strength=args.offset_noise_strength,
             no_val=args.valid_set_size == 0,
+            avg_loss=avg_loss,
+            avg_acc=avg_acc,
+            avg_loss_val=avg_loss_val,
+            avg_acc_val=avg_acc_val,
         )
 
         training_iter += 1
diff --git a/train_ti.py b/train_ti.py
index d1e5467..fce4a5e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.util import save_args
+from training.util import AverageMeter, save_args
 
 logger = get_logger(__name__)
 
@@ -920,6 +920,11 @@ def main():
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
 
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -977,7 +982,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -994,6 +999,10 @@ def main():
             sample_frequency=sample_frequency,
             placeholder_tokens=placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
+            avg_loss=avg_loss,
+            avg_acc=avg_acc,
+            avg_loss_val=avg_loss_val,
+            avg_acc_val=avg_acc_val,
         )
 
         training_iter += 1
diff --git a/training/functional.py b/training/functional.py
index 695a24f..3036ed9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -461,6 +461,10 @@ def train_loop(
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
     group_labels: list[str] = [],
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -472,14 +476,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    best_acc = 0.0
-    best_acc_val = 0.0
+    best_acc = avg_acc.avg
+    best_acc_val = avg_acc_val.avg
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -544,12 +542,12 @@ def train_loop(
 
                 accelerator.backward(loss)
 
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
+                avg_loss.update(loss.item(), bsz)
+                avg_acc.update(acc.item(), bsz)
 
                 logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
+                    "train/loss": avg_loss.avg,
+                    "train/acc": avg_acc.avg,
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
                 }
@@ -603,47 +601,47 @@ def train_loop(
                     loss = loss.detach_()
                     acc = acc.detach_()
 
-                    cur_loss_val.update(loss, bsz)
-                    cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss.item(), bsz)
+                    cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss, bsz)
-                    avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss.item(), bsz)
+                    avg_acc_val.update(acc.item(), bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                     logs = {
-                        "val/loss": avg_loss_val.avg.item(),
-                        "val/acc": avg_acc_val.avg.item(),
+                        "val/loss": avg_loss_val.avg,
+                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
-                logs["val/cur_loss"] = cur_loss_val.avg.item()
-                logs["val/cur_acc"] = cur_acc_val.avg.item()
+                logs["val/cur_loss"] = cur_loss_val.avg
+                logs["val/cur_acc"] = cur_acc_val.avg
 
                 accelerator.log(logs, step=global_step)
 
                 if accelerator.is_main_process:
-                    if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
+                    if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
                         accelerator.print(
-                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                         on_checkpoint(global_step, "milestone")
-                        best_acc_val = avg_acc_val.avg.item()
+                        best_acc_val = avg_acc_val.avg
             else:
                 if accelerator.is_main_process:
-                    if avg_acc.avg.item() > best_acc and milestone_checkpoints:
+                    if avg_acc.avg > best_acc and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
                         accelerator.print(
-                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                         on_checkpoint(global_step, "milestone")
-                        best_acc = avg_acc.avg.item()
+                        best_acc = avg_acc.avg
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -688,6 +686,10 @@ def train(
     offset_noise_strength: float = 0.15,
     disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
@@ -737,6 +739,10 @@ def train(
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
         group_labels=group_labels,
+        avg_loss=avg_loss,
+        avg_acc=avg_acc,
+        avg_loss_val=avg_loss_val,
+        avg_acc_val=avg_acc_val,
         callbacks=callbacks,
     )
 
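Note: a minimal sketch of why the meters are now constructed by the caller and passed in — they persist across repeated trainer() calls, so the logged running averages and the best-accuracy baseline (best_acc = avg_acc.avg) carry over from one auto-cycle to the next. The inner loop below merely stands in for one training cycle; it assumes the repository root is on sys.path:

from training.util import AverageMeter

avg_loss = AverageMeter()
avg_acc = AverageMeter()

for cycle in range(2):
    # stand-in for one trainer(..., avg_loss=avg_loss, avg_acc=avg_acc, ...) call
    for batch_loss, batch_acc in [(0.9, 0.30), (0.7, 0.45)]:
        avg_loss.update(batch_loss)
        avg_acc.update(batch_acc)
    # because the meters persist, cycle 1 continues from cycle 0's running values
    print(f"cycle {cycle}: avg_loss={avg_loss.avg:.3f} avg_acc={avg_acc.avg:.3f}")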
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1f0a117..3f4dbbc 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -138,6 +138,14 @@ def lora_strategy_callbacks(
         state_dict.update(text_encoder_state_dict)
         lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
+        if len(placeholder_tokens) != 0:
+            ti_state_dict = {
+                f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
+                for (token, ids)
+                in zip(placeholder_tokens, placeholder_token_ids)
+            }
+            state_dict.update(ti_state_dict)
+
         save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
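Note: a self-contained sketch of the resulting checkpoint layout — the placeholder-token embeddings sit under ti_-prefixed keys next to the LoRA weights in the same safetensors file. All tensors, key names, and the file name below are illustrative:

import torch
from safetensors.torch import save_file, load_file

state_dict = {
    # stand-in for a PEFT LoRA weight collected above
    "text_encoder_text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 768),
    # one embedding matrix per placeholder token, keyed as in the strategy code above
    "ti_$<concept>": torch.zeros(2, 768),
}

save_file(state_dict, "10_end.safetensors")
print(sorted(load_file("10_end.safetensors").keys()))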
diff --git a/training/util.py b/training/util.py
index 8bd8a83..61f1533 100644
--- a/training/util.py
+++ b/training/util.py
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
-    avg: Any
-
-    def __init__(self, name=None):
-        self.name = name
+    def __init__(self, inv_gamma=1.0, power=2 / 3):
+        self.inv_gamma = inv_gamma
+        self.power = power
         self.reset()
 
     def reset(self):
-        self.sum = self.count = self.avg = 0
+        self.step = 0
+        self.avg = 0
+
+    def get_decay(self):
+        if self.step <= 0:
+            return 1
+
+        return (self.step / self.inv_gamma) ** -self.power
 
     def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
+        for _ in range(n):
+            self.step += n
+            self.avg += (val - self.avg) * self.get_decay()
 
 
 class EMAModel(EMAModel_):
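Note: a small sketch of the decay schedule the rewritten AverageMeter uses. With the new defaults inv_gamma=1.0 and power=2/3, the first update replaces avg outright (decay 1.0) and later updates blend in with steadily shrinking weight, so avg becomes a smoothed running value rather than the old exact mean:

inv_gamma, power = 1.0, 2 / 3  # defaults from the diff above

for step in range(1, 6):
    decay = (step / inv_gamma) ** -power
    print(f"step {step}: decay = {decay:.3f}")
# step 1: decay = 1.000
# step 2: decay = 0.630
# step 3: decay = 0.481
# step 4: decay = 0.397
# step 5: decay = 0.342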
diff --git a/util/ti.py b/util/ti.py
new file mode 100644
index 0000000..4cc732e
--- /dev/null
+++ b/util/ti.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+
+import torch
+
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+
+def load_embeddings(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    tokens: list[str],
+    token_embeddings: torch.FloatTensor,
+):
+    num_vectors = [embedding.shape[0] for embedding in token_embeddings]
+
+    token_ids = tokenizer.add_multi_tokens(tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, embeds) in zip(token_ids, token_embeddings):
+        embeddings.add_embed(new_id, embeds)
+
+    return tokens, token_ids
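Note: a minimal sketch of the shape convention load_embeddings() expects — one (num_vectors, embedding_dim) tensor per token, so a single placeholder token may span several tokenizer entries. The 768-dimensional size is illustrative:

import torch

token_embeddings = [torch.zeros(1, 768), torch.zeros(3, 768)]
num_vectors = [embedding.shape[0] for embedding in token_embeddings]
print(num_vectors)  # [1, 3] -> add_multi_tokens() allocates that many ids for each token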