 infer.py                  | 22
 train_lora.py             | 13
 train_ti.py               | 13
 training/functional.py    | 58
 training/strategy/lora.py |  8
 training/util.py          | 22
 util/ti.py                | 24
 7 files changed, 119 insertions(+), 41 deletions(-)
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -35,6 +35,7 @@ from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from util.files import load_config, load_embeddings_from_dir
+from util.ti import load_embeddings
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -229,7 +230,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(pipeline, embeddings_dir):
+def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
@@ -258,6 +259,9 @@ def load_lora(pipeline, path):
     text_encoder_lora_ds = {
         k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
     }
+    ti_lora_ds = {
+        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
+    }
 
     unet_config = LoraConfig(**lora_config["peft_config"])
     pipeline.unet = LoraModel(unet_config, pipeline.unet)
@@ -268,6 +272,18 @@ def load_lora(pipeline, path):
     pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
     set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
 
+    tokens = [k for k, _ in ti_lora_ds]
+    token_embeddings = [v for _, v in ti_lora_ds]
+
+    added_tokens, added_ids = load_embeddings(
+        tokenizer=pipeline.tokenizer,
+        embeddings=pipeline.text_encoder.text_model.embeddings,
+        tokens=tokens,
+        token_embeddings=token_embeddings,
+    )
+    pipeline.text_encoder.text_model.embeddings.persist()
+    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+
     return
 
 
@@ -435,7 +451,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         if elements[0] == 'reload_embeddings':
-            load_embeddings(self.pipeline, self.ti_embeddings_dir)
+            load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
             return
 
         try:
@@ -475,7 +491,7 @@ def main():
 
     pipeline = create_pipeline(args.model, dtype)
 
-    load_embeddings(pipeline, args.ti_embeddings_dir)
+    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
     load_lora(pipeline, args.lora_embedding)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
diff --git a/train_lora.py b/train_lora.py
index c197206..9cf17c7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
-from training.util import save_args
+from training.util import AverageMeter, save_args
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
@@ -1035,6 +1035,11 @@ def main():
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
 
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -1122,7 +1127,7 @@ def main():
             warmup_epochs=lr_warmup_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}"
 
         trainer(
             strategy=lora_strategy,
@@ -1142,6 +1147,10 @@ def main():
             sample_frequency=lora_sample_frequency,
             offset_noise_strength=args.offset_noise_strength,
             no_val=args.valid_set_size == 0,
+            avg_loss=avg_loss,
+            avg_acc=avg_acc,
+            avg_loss_val=avg_loss_val,
+            avg_acc_val=avg_acc_val,
         )
 
         training_iter += 1
diff --git a/train_ti.py b/train_ti.py
index d1e5467..fce4a5e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.util import save_args
+from training.util import AverageMeter, save_args
 
 logger = get_logger(__name__)
 
@@ -920,6 +920,11 @@ def main():
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
 
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -977,7 +982,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -994,6 +999,10 @@ def main():
             sample_frequency=sample_frequency,
             placeholder_tokens=placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
+            avg_loss=avg_loss,
+            avg_acc=avg_acc,
+            avg_loss_val=avg_loss_val,
+            avg_acc_val=avg_acc_val,
        )
 
         training_iter += 1
diff --git a/training/functional.py b/training/functional.py
index 695a24f..3036ed9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -461,6 +461,10 @@ def train_loop(
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
     group_labels: list[str] = [],
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -472,14 +476,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    best_acc = 0.0
-    best_acc_val = 0.0
+    best_acc = avg_acc.avg
+    best_acc_val = avg_acc_val.avg
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -544,12 +542,12 @@
 
                 accelerator.backward(loss)
 
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
+                avg_loss.update(loss.item(), bsz)
+                avg_acc.update(acc.item(), bsz)
 
                 logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
+                    "train/loss": avg_loss.avg,
+                    "train/acc": avg_acc.avg,
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
                 }
@@ -603,47 +601,47 @@
                     loss = loss.detach_()
                     acc = acc.detach_()
 
-                    cur_loss_val.update(loss, bsz)
-                    cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss.item(), bsz)
+                    cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss, bsz)
-                    avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss.item(), bsz)
+                    avg_acc_val.update(acc.item(), bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                 logs = {
-                    "val/loss": avg_loss_val.avg.item(),
-                    "val/acc": avg_acc_val.avg.item(),
+                    "val/loss": avg_loss_val.avg,
+                    "val/acc": avg_acc_val.avg,
                     "val/cur_loss": loss.item(),
                     "val/cur_acc": acc.item(),
                 }
                 local_progress_bar.set_postfix(**logs)
 
-                logs["val/cur_loss"] = cur_loss_val.avg.item()
-                logs["val/cur_acc"] = cur_acc_val.avg.item()
+                logs["val/cur_loss"] = cur_loss_val.avg
+                logs["val/cur_acc"] = cur_acc_val.avg
 
                 accelerator.log(logs, step=global_step)
 
                 if accelerator.is_main_process:
-                    if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
+                    if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
                         accelerator.print(
-                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                         on_checkpoint(global_step, "milestone")
-                        best_acc_val = avg_acc_val.avg.item()
+                        best_acc_val = avg_acc_val.avg
             else:
                 if accelerator.is_main_process:
-                    if avg_acc.avg.item() > best_acc and milestone_checkpoints:
+                    if avg_acc.avg > best_acc and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
                         accelerator.print(
-                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                         on_checkpoint(global_step, "milestone")
-                        best_acc = avg_acc.avg.item()
+                        best_acc = avg_acc.avg
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -688,6 +686,10 @@ def train(
     offset_noise_strength: float = 0.15,
     disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
@@ -737,6 +739,10 @@ def train(
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
         group_labels=group_labels,
+        avg_loss=avg_loss,
+        avg_acc=avg_acc,
+        avg_loss_val=avg_loss_val,
+        avg_acc_val=avg_acc_val,
         callbacks=callbacks,
     )
 
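Note: constructing the meters in the callers (train_lora.py / train_ti.py above) and threading them through train() into train_loop() means the running averages, and with them the milestone baselines best_acc / best_acc_val, now carry over from one training cycle to the next instead of resetting to zero on every call. A minimal sketch of that behaviour using the AverageMeter from training/util.py; the run_cycle helper is made up purely for illustration:

    # Sketch: one meter shared across cycles; each cycle's milestone baseline
    # starts from wherever the previous cycle left the running average.
    from training.util import AverageMeter

    def run_cycle(avg_acc, accuracies):
        best_acc = avg_acc.avg               # carried over, no longer hard-coded 0.0
        for acc in accuracies:
            avg_acc.update(acc)
            if avg_acc.avg > best_acc:       # would trigger a "milestone" checkpoint
                best_acc = avg_acc.avg
        return best_acc

    avg_acc = AverageMeter()
    run_cycle(avg_acc, [0.2, 0.4, 0.5])      # first cycle
    run_cycle(avg_acc, [0.55, 0.6])          # second cycle resumes from the first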
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1f0a117..3f4dbbc 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -138,6 +138,14 @@ def lora_strategy_callbacks(
         state_dict.update(text_encoder_state_dict)
         lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
+        if len(placeholder_tokens) != 0:
+            ti_state_dict = {
+                f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
+                for (token, ids)
+                in zip(placeholder_tokens, placeholder_token_ids)
+            }
+            state_dict.update(ti_state_dict)
+
         save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
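Note: with this change the textual-inversion vectors are written into the same safetensors file as the LoRA weights, under keys carrying a "ti_" prefix, which is the prefix the new ti_lora_ds filter in infer.py strips off again at load time. A rough sketch of that round trip with made-up key names and tensor values, independent of the pipeline:

    # Sketch of the save/load key convention; tensor contents are placeholders.
    import torch
    from safetensors.torch import save_file, load_file

    state_dict = {"unet_down_blocks.lora.weight": torch.zeros(4, 4)}   # LoRA part
    state_dict["ti_<token>"] = torch.randn(2, 768)                     # 2 TI vectors

    save_file(state_dict, "model.safetensors")

    loaded = load_file("model.safetensors")
    ti_entries = {k.replace("ti_", ""): v for k, v in loaded.items() if "ti_" in k}
    # ti_entries now maps the token name to its (num_vectors, dim) embedding tensor.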
diff --git a/training/util.py b/training/util.py
index 8bd8a83..61f1533 100644
--- a/training/util.py
+++ b/training/util.py
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
-    avg: Any
-
-    def __init__(self, name=None):
-        self.name = name
+    def __init__(self, inv_gamma=1.0, power=2 / 3):
+        self.inv_gamma = inv_gamma
+        self.power = power
         self.reset()
 
     def reset(self):
-        self.sum = self.count = self.avg = 0
+        self.step = 0
+        self.avg = 0
+
+    def get_decay(self):
+        if self.step <= 0:
+            return 1
+
+        return (self.step / self.inv_gamma) ** -self.power
 
     def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
+        for _ in range(n):
+            self.step += n
+            self.avg += (val - self.avg) * self.get_decay()
 
 
 class EMAModel(EMAModel_):
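Note: AverageMeter no longer keeps a plain arithmetic mean over sum/count; it now tracks an exponential moving average whose blend factor decays polynomially with the step count, (step / inv_gamma) ** -power, so the first update is taken as-is and later updates move the average progressively less. A standalone sketch of the same recurrence for the default n=1 case:

    # Sketch of the decayed-EMA recurrence behind the new AverageMeter.
    def decay(step, inv_gamma=1.0, power=2 / 3):
        # Blend factor for the newest value: 1 at the first step, then shrinking.
        return 1.0 if step <= 0 else (step / inv_gamma) ** -power

    avg, step = 0.0, 0
    for val in [1.0, 1.0, 0.5, 0.0]:
        step += 1
        avg += (val - avg) * decay(step)
        print(step, avg)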
diff --git a/util/ti.py b/util/ti.py
new file mode 100644
index 0000000..4cc732e
--- /dev/null
+++ b/util/ti.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+
+import torch
+
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+
+def load_embeddings(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    tokens: list[str],
+    token_embeddings: torch.FloatTensor,
+):
+    num_vectors = [embedding.shape[0] for embedding in token_embeddings]
+
+    token_ids = tokenizer.add_multi_tokens(tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, embeds) in zip(token_ids, token_embeddings):
+        embeddings.add_embed(new_id, embeds)
+
+    return tokens, token_ids
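Note: load_embeddings pairs with the call added to load_lora in infer.py. Each entry of token_embeddings is a (num_vectors, embedding_dim) tensor; the number of sub-tokens per placeholder is taken from its first dimension, the tokenizer and embedding matrix are grown to match, and the vectors are copied into the new rows. A hedged usage example, where pipeline is assumed to be an already loaded VlpnStableDiffusion and the tensor values are made up:

    import torch
    from util.ti import load_embeddings

    tokens = ["<concept>"]
    token_embeddings = [torch.randn(3, 768)]   # one placeholder expanded to 3 vectors

    added_tokens, added_ids = load_embeddings(
        tokenizer=pipeline.tokenizer,
        embeddings=pipeline.text_encoder.text_model.embeddings,
        tokens=tokens,
        token_embeddings=token_embeddings,
    )
    pipeline.text_encoder.text_model.embeddings.persist()   # keep the new rows in place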