Diffstat (limited to 'training')
-rw-r--r--  training/functional.py    | 58 ++++++++++++++++++++++++++++++++--------------------------
-rw-r--r--  training/strategy/lora.py |  8 ++++++++
-rw-r--r--  training/util.py          | 22 ++++++++++++++--------
3 files changed, 54 insertions(+), 34 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 695a24f..3036ed9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -461,6 +461,10 @@ def train_loop(
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
     group_labels: list[str] = [],
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -472,14 +476,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    best_acc = 0.0
-    best_acc_val = 0.0
+    best_acc = avg_acc.avg
+    best_acc_val = avg_acc_val.avg
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -544,12 +542,12 @@ def train_loop(
 
             accelerator.backward(loss)
 
-            avg_loss.update(loss.detach_(), bsz)
-            avg_acc.update(acc.detach_(), bsz)
+            avg_loss.update(loss.item(), bsz)
+            avg_acc.update(acc.item(), bsz)
 
             logs = {
-                "train/loss": avg_loss.avg.item(),
-                "train/acc": avg_acc.avg.item(),
+                "train/loss": avg_loss.avg,
+                "train/acc": avg_acc.avg,
                 "train/cur_loss": loss.item(),
                 "train/cur_acc": acc.item(),
             }
@@ -603,47 +601,47 @@ def train_loop(
                     loss = loss.detach_()
                     acc = acc.detach_()
 
-                    cur_loss_val.update(loss, bsz)
-                    cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss.item(), bsz)
+                    cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss, bsz)
-                    avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss.item(), bsz)
+                    avg_acc_val.update(acc.item(), bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                     logs = {
-                        "val/loss": avg_loss_val.avg.item(),
-                        "val/acc": avg_acc_val.avg.item(),
+                        "val/loss": avg_loss_val.avg,
+                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            logs["val/cur_loss"] = cur_loss_val.avg.item()
-            logs["val/cur_acc"] = cur_acc_val.avg.item()
+            logs["val/cur_loss"] = cur_loss_val.avg
+            logs["val/cur_acc"] = cur_acc_val.avg
 
             accelerator.log(logs, step=global_step)
 
             if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
+                if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc_val = avg_acc_val.avg.item()
+                    best_acc_val = avg_acc_val.avg
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg.item() > best_acc and milestone_checkpoints:
+                if avg_acc.avg > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
-                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc = avg_acc.avg.item()
+                    best_acc = avg_acc.avg
 
     # Create the pipeline using the trained modules and save it.
     if accelerator.is_main_process:
@@ -688,6 +686,10 @@ def train(
     offset_noise_strength: float = 0.15,
     disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
@@ -737,6 +739,10 @@ def train(
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
         group_labels=group_labels,
+        avg_loss=avg_loss,
+        avg_acc=avg_acc,
+        avg_loss_val=avg_loss_val,
+        avg_acc_val=avg_acc_val,
         callbacks=callbacks,
     )
 
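Note: because train_loop now seeds best_acc and best_acc_val from the incoming meters instead of hard-coded 0.0, caller-owned meters let milestone tracking persist across multiple train() invocations. A minimal sketch of the calling pattern this enables; the two-phase schedule and the elided arguments are hypothetical:

    from training.util import AverageMeter

    # Caller-owned meters: their running averages, and therefore the
    # milestone baselines, carry over from one phase to the next.
    avg_loss, avg_acc = AverageMeter(), AverageMeter()
    avg_loss_val, avg_acc_val = AverageMeter(), AverageMeter()

    for phase in ("ti", "lora"):  # hypothetical two-phase schedule
        train(
            ...,  # model, dataloaders, strategy, etc. as at the real call site
            avg_loss=avg_loss,
            avg_acc=avg_acc,
            avg_loss_val=avg_loss_val,
            avg_acc_val=avg_acc_val,
        )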
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1f0a117..3f4dbbc 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -138,6 +138,14 @@ def lora_strategy_callbacks(
         state_dict.update(text_encoder_state_dict)
         lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
+        if len(placeholder_tokens) != 0:
+            ti_state_dict = {
+                f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
+                for (token, ids)
+                in zip(placeholder_tokens, placeholder_token_ids)
+            }
+            state_dict.update(ti_state_dict)
+
         save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
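Note: textual-inversion embeddings now ride along in the same safetensors checkpoint as the LoRA weights, under keys of the form ti_$<token> (the "$" is a literal character in the f-string above). A minimal sketch of reading them back, with a hypothetical checkpoint path:

    from safetensors.torch import load_file

    state_dict = load_file("checkpoints/1000_end.safetensors")  # hypothetical path

    # Keys written by the callback above look like "ti_$<token>";
    # the remaining entries are the LoRA weights.
    ti_embeds = {k: v for k, v in state_dict.items() if k.startswith("ti_$")}
    for key, embed in ti_embeds.items():
        print(key, tuple(embed.shape))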
diff --git a/training/util.py b/training/util.py
index 8bd8a83..61f1533 100644
--- a/training/util.py
+++ b/training/util.py
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
-    avg: Any
-
-    def __init__(self, name=None):
-        self.name = name
+    def __init__(self, inv_gamma=1.0, power=2 / 3):
+        self.inv_gamma = inv_gamma
+        self.power = power
         self.reset()
 
     def reset(self):
-        self.sum = self.count = self.avg = 0
+        self.step = 0
+        self.avg = 0
+
+    def get_decay(self):
+        if self.step <= 0:
+            return 1
+
+        return (self.step / self.inv_gamma) ** -self.power
 
     def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
+        for _ in range(n):
+            self.step += n
+            self.avg += (val - self.avg) * self.get_decay()
 
 
 class EMAModel(EMAModel_):
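Note: AverageMeter is no longer a plain arithmetic mean; update() now applies an exponential moving average whose decay follows an inverse power warmup, (step / inv_gamma) ** -power, so the first samples are taken nearly verbatim and later samples are increasingly damped. A standalone sketch of the new behavior with the default parameters:

    from training.util import AverageMeter  # the class as patched above

    meter = AverageMeter(inv_gamma=1.0, power=2 / 3)

    for val in (1.0, 0.5, 0.25, 0.125):
        meter.update(val)  # n=1: decay = step ** -(2 / 3) -> 1.0, ~0.63, ~0.48, ~0.40
        print(f"step {meter.step}: avg={meter.avg:.4f}")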