diff options
author | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-27 07:47:59 +0200 |
commit | 6d46bf79bd7710cea799fbfe27c12d06d12cd53f (patch) | |
tree | 6c65817b9351453bfb5366f7010f8d87659c0dd0 /training | |
parent | Fix cycle loop (diff) | |
download | textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.gz textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.tar.bz2 textual-inversion-diff-6d46bf79bd7710cea799fbfe27c12d06d12cd53f.zip |
Update
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 58 | ||||
-rw-r--r-- | training/strategy/lora.py | 8 | ||||
-rw-r--r-- | training/util.py | 22 |
3 files changed, 54 insertions, 34 deletions
diff --git a/training/functional.py b/training/functional.py index 695a24f..3036ed9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -461,6 +461,10 @@ def train_loop( | |||
461 | num_epochs: int = 100, | 461 | num_epochs: int = 100, |
462 | gradient_accumulation_steps: int = 1, | 462 | gradient_accumulation_steps: int = 1, |
463 | group_labels: list[str] = [], | 463 | group_labels: list[str] = [], |
464 | avg_loss: AverageMeter = AverageMeter(), | ||
465 | avg_acc: AverageMeter = AverageMeter(), | ||
466 | avg_loss_val: AverageMeter = AverageMeter(), | ||
467 | avg_acc_val: AverageMeter = AverageMeter(), | ||
464 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 468 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
465 | ): | 469 | ): |
466 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | 470 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) |
@@ -472,14 +476,8 @@ def train_loop( | |||
472 | global_step = 0 | 476 | global_step = 0 |
473 | cache = {} | 477 | cache = {} |
474 | 478 | ||
475 | avg_loss = AverageMeter() | 479 | best_acc = avg_acc.avg |
476 | avg_acc = AverageMeter() | 480 | best_acc_val = avg_acc_val.avg |
477 | |||
478 | avg_loss_val = AverageMeter() | ||
479 | avg_acc_val = AverageMeter() | ||
480 | |||
481 | best_acc = 0.0 | ||
482 | best_acc_val = 0.0 | ||
483 | 481 | ||
484 | local_progress_bar = tqdm( | 482 | local_progress_bar = tqdm( |
485 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 483 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
@@ -544,12 +542,12 @@ def train_loop( | |||
544 | 542 | ||
545 | accelerator.backward(loss) | 543 | accelerator.backward(loss) |
546 | 544 | ||
547 | avg_loss.update(loss.detach_(), bsz) | 545 | avg_loss.update(loss.item(), bsz) |
548 | avg_acc.update(acc.detach_(), bsz) | 546 | avg_acc.update(acc.item(), bsz) |
549 | 547 | ||
550 | logs = { | 548 | logs = { |
551 | "train/loss": avg_loss.avg.item(), | 549 | "train/loss": avg_loss.avg, |
552 | "train/acc": avg_acc.avg.item(), | 550 | "train/acc": avg_acc.avg, |
553 | "train/cur_loss": loss.item(), | 551 | "train/cur_loss": loss.item(), |
554 | "train/cur_acc": acc.item(), | 552 | "train/cur_acc": acc.item(), |
555 | } | 553 | } |
@@ -603,47 +601,47 @@ def train_loop( | |||
603 | loss = loss.detach_() | 601 | loss = loss.detach_() |
604 | acc = acc.detach_() | 602 | acc = acc.detach_() |
605 | 603 | ||
606 | cur_loss_val.update(loss, bsz) | 604 | cur_loss_val.update(loss.item(), bsz) |
607 | cur_acc_val.update(acc, bsz) | 605 | cur_acc_val.update(acc.item(), bsz) |
608 | 606 | ||
609 | avg_loss_val.update(loss, bsz) | 607 | avg_loss_val.update(loss.item(), bsz) |
610 | avg_acc_val.update(acc, bsz) | 608 | avg_acc_val.update(acc.item(), bsz) |
611 | 609 | ||
612 | local_progress_bar.update(1) | 610 | local_progress_bar.update(1) |
613 | global_progress_bar.update(1) | 611 | global_progress_bar.update(1) |
614 | 612 | ||
615 | logs = { | 613 | logs = { |
616 | "val/loss": avg_loss_val.avg.item(), | 614 | "val/loss": avg_loss_val.avg, |
617 | "val/acc": avg_acc_val.avg.item(), | 615 | "val/acc": avg_acc_val.avg, |
618 | "val/cur_loss": loss.item(), | 616 | "val/cur_loss": loss.item(), |
619 | "val/cur_acc": acc.item(), | 617 | "val/cur_acc": acc.item(), |
620 | } | 618 | } |
621 | local_progress_bar.set_postfix(**logs) | 619 | local_progress_bar.set_postfix(**logs) |
622 | 620 | ||
623 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 621 | logs["val/cur_loss"] = cur_loss_val.avg |
624 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 622 | logs["val/cur_acc"] = cur_acc_val.avg |
625 | 623 | ||
626 | accelerator.log(logs, step=global_step) | 624 | accelerator.log(logs, step=global_step) |
627 | 625 | ||
628 | if accelerator.is_main_process: | 626 | if accelerator.is_main_process: |
629 | if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints: | 627 | if avg_acc_val.avg > best_acc_val and milestone_checkpoints: |
630 | local_progress_bar.clear() | 628 | local_progress_bar.clear() |
631 | global_progress_bar.clear() | 629 | global_progress_bar.clear() |
632 | 630 | ||
633 | accelerator.print( | 631 | accelerator.print( |
634 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 632 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") |
635 | on_checkpoint(global_step, "milestone") | 633 | on_checkpoint(global_step, "milestone") |
636 | best_acc_val = avg_acc_val.avg.item() | 634 | best_acc_val = avg_acc_val.avg |
637 | else: | 635 | else: |
638 | if accelerator.is_main_process: | 636 | if accelerator.is_main_process: |
639 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: | 637 | if avg_acc.avg > best_acc and milestone_checkpoints: |
640 | local_progress_bar.clear() | 638 | local_progress_bar.clear() |
641 | global_progress_bar.clear() | 639 | global_progress_bar.clear() |
642 | 640 | ||
643 | accelerator.print( | 641 | accelerator.print( |
644 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") | 642 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") |
645 | on_checkpoint(global_step, "milestone") | 643 | on_checkpoint(global_step, "milestone") |
646 | best_acc = avg_acc.avg.item() | 644 | best_acc = avg_acc.avg |
647 | 645 | ||
648 | # Create the pipeline using using the trained modules and save it. | 646 | # Create the pipeline using using the trained modules and save it. |
649 | if accelerator.is_main_process: | 647 | if accelerator.is_main_process: |
@@ -688,6 +686,10 @@ def train( | |||
688 | offset_noise_strength: float = 0.15, | 686 | offset_noise_strength: float = 0.15, |
689 | disc: Optional[ConvNeXtDiscriminator] = None, | 687 | disc: Optional[ConvNeXtDiscriminator] = None, |
690 | min_snr_gamma: int = 5, | 688 | min_snr_gamma: int = 5, |
689 | avg_loss: AverageMeter = AverageMeter(), | ||
690 | avg_acc: AverageMeter = AverageMeter(), | ||
691 | avg_loss_val: AverageMeter = AverageMeter(), | ||
692 | avg_acc_val: AverageMeter = AverageMeter(), | ||
691 | **kwargs, | 693 | **kwargs, |
692 | ): | 694 | ): |
693 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( | 695 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( |
@@ -737,6 +739,10 @@ def train( | |||
737 | num_epochs=num_train_epochs, | 739 | num_epochs=num_train_epochs, |
738 | gradient_accumulation_steps=gradient_accumulation_steps, | 740 | gradient_accumulation_steps=gradient_accumulation_steps, |
739 | group_labels=group_labels, | 741 | group_labels=group_labels, |
742 | avg_loss=avg_loss, | ||
743 | avg_acc=avg_acc, | ||
744 | avg_loss_val=avg_loss_val, | ||
745 | avg_acc_val=avg_acc_val, | ||
740 | callbacks=callbacks, | 746 | callbacks=callbacks, |
741 | ) | 747 | ) |
742 | 748 | ||
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 1f0a117..3f4dbbc 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -138,6 +138,14 @@ def lora_strategy_callbacks( | |||
138 | state_dict.update(text_encoder_state_dict) | 138 | state_dict.update(text_encoder_state_dict) |
139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) |
140 | 140 | ||
141 | if len(placeholder_tokens) != 0: | ||
142 | ti_state_dict = { | ||
143 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) | ||
144 | for (token, ids) | ||
145 | in zip(placeholder_tokens, placeholder_token_ids) | ||
146 | } | ||
147 | state_dict.update(ti_state_dict) | ||
148 | |||
141 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") | 149 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") |
142 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 150 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
143 | json.dump(lora_config, f) | 151 | json.dump(lora_config, f) |
diff --git a/training/util.py b/training/util.py index 8bd8a83..61f1533 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}): | |||
16 | 16 | ||
17 | 17 | ||
18 | class AverageMeter: | 18 | class AverageMeter: |
19 | avg: Any | 19 | def __init__(self, inv_gamma=1.0, power=2 / 3): |
20 | 20 | self.inv_gamma = inv_gamma | |
21 | def __init__(self, name=None): | 21 | self.power = power |
22 | self.name = name | ||
23 | self.reset() | 22 | self.reset() |
24 | 23 | ||
25 | def reset(self): | 24 | def reset(self): |
26 | self.sum = self.count = self.avg = 0 | 25 | self.step = 0 |
26 | self.avg = 0 | ||
27 | |||
28 | def get_decay(self): | ||
29 | if self.step <= 0: | ||
30 | return 1 | ||
31 | |||
32 | return (self.step / self.inv_gamma) ** -self.power | ||
27 | 33 | ||
28 | def update(self, val, n=1): | 34 | def update(self, val, n=1): |
29 | self.sum += val * n | 35 | for _ in range(n): |
30 | self.count += n | 36 | self.step += n |
31 | self.avg = self.sum / self.count | 37 | self.avg += (val - self.avg) * self.get_decay() |
32 | 38 | ||
33 | 39 | ||
34 | class EMAModel(EMAModel_): | 40 | class EMAModel(EMAModel_): |