From 6d46bf79bd7710cea799fbfe27c12d06d12cd53f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 27 Apr 2023 07:47:59 +0200
Subject: Update

---
 training/functional.py    | 58 ++++++++++++++++++++++++++---------------------
 training/strategy/lora.py |  8 +++++++
 training/util.py          | 22 +++++++++++-------
 3 files changed, 54 insertions(+), 34 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 695a24f..3036ed9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -461,6 +461,10 @@ def train_loop(
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
     group_labels: list[str] = [],
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -472,14 +476,8 @@ def train_loop(
     global_step = 0
     cache = {}
 
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    best_acc = 0.0
-    best_acc_val = 0.0
+    best_acc = avg_acc.avg
+    best_acc_val = avg_acc_val.avg
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -544,12 +542,12 @@ def train_loop(
 
                 accelerator.backward(loss)
 
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
+                avg_loss.update(loss.item(), bsz)
+                avg_acc.update(acc.item(), bsz)
 
                 logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
+                    "train/loss": avg_loss.avg,
+                    "train/acc": avg_acc.avg,
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
                 }
@@ -603,47 +601,47 @@ def train_loop(
                         loss = loss.detach_()
                         acc = acc.detach_()
 
-                        cur_loss_val.update(loss, bsz)
-                        cur_acc_val.update(acc, bsz)
+                        cur_loss_val.update(loss.item(), bsz)
+                        cur_acc_val.update(acc.item(), bsz)
 
-                        avg_loss_val.update(loss, bsz)
-                        avg_acc_val.update(acc, bsz)
+                        avg_loss_val.update(loss.item(), bsz)
+                        avg_acc_val.update(acc.item(), bsz)
 
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
 
                         logs = {
-                            "val/loss": avg_loss_val.avg.item(),
-                            "val/acc": avg_acc_val.avg.item(),
+                            "val/loss": avg_loss_val.avg,
+                            "val/acc": avg_acc_val.avg,
                             "val/cur_loss": loss.item(),
                             "val/cur_acc": acc.item(),
                         }
                         local_progress_bar.set_postfix(**logs)
 
-                    logs["val/cur_loss"] = cur_loss_val.avg.item()
-                    logs["val/cur_acc"] = cur_acc_val.avg.item()
+                    logs["val/cur_loss"] = cur_loss_val.avg
+                    logs["val/cur_acc"] = cur_acc_val.avg
 
                     accelerator.log(logs, step=global_step)
 
                 if accelerator.is_main_process:
-                    if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
+                    if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
                         accelerator.print(
-                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                            f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                         on_checkpoint(global_step, "milestone")
-                        best_acc_val = avg_acc_val.avg.item()
+                        best_acc_val = avg_acc_val.avg
             else:
                 if accelerator.is_main_process:
-                    if avg_acc.avg.item() > best_acc and milestone_checkpoints:
+                    if avg_acc.avg > best_acc and milestone_checkpoints:
                         local_progress_bar.clear()
                         global_progress_bar.clear()
 
                         accelerator.print(
-                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                            f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                         on_checkpoint(global_step, "milestone")
-                        best_acc = avg_acc.avg.item()
+                        best_acc = avg_acc.avg
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -688,6 +686,10 @@ def train(
     offset_noise_strength: float = 0.15,
     disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
@@ -737,6 +739,10 @@ def train(
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
         group_labels=group_labels,
+        avg_loss=avg_loss,
+        avg_acc=avg_acc,
+        avg_loss_val=avg_loss_val,
+        avg_acc_val=avg_acc_val,
         callbacks=callbacks,
     )
 
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1f0a117..3f4dbbc 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -138,6 +138,14 @@ def lora_strategy_callbacks(
             state_dict.update(text_encoder_state_dict)
             lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
+        if len(placeholder_tokens) != 0:
+            ti_state_dict = {
+                f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
+                for (token, ids)
+                in zip(placeholder_tokens, placeholder_token_ids)
+            }
+            state_dict.update(ti_state_dict)
+
         save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
diff --git a/training/util.py b/training/util.py
index 8bd8a83..61f1533 100644
--- a/training/util.py
+++ b/training/util.py
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
-    avg: Any
-
-    def __init__(self, name=None):
-        self.name = name
+    def __init__(self, inv_gamma=1.0, power=2 / 3):
+        self.inv_gamma = inv_gamma
+        self.power = power
         self.reset()
 
     def reset(self):
-        self.sum = self.count = self.avg = 0
+        self.step = 0
+        self.avg = 0
+
+    def get_decay(self):
+        if self.step <= 0:
+            return 1
+
+        return (self.step / self.inv_gamma) ** -self.power
 
     def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
+        for _ in range(n):
+            self.step += n
+            self.avg += (val - self.avg) * self.get_decay()
 
 
 class EMAModel(EMAModel_):
-- 
cgit v1.2.3-70-g09d2
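
The AverageMeter rework in training/util.py changes what `avg` means: instead of the arithmetic mean sum / count, it now keeps a decayed average whose per-update decay is (step / inv_gamma) ** -power, clamped to 1 for the first step, so recent values weigh more than old ones. Below is a minimal standalone sketch of that behaviour; the class body mirrors the patch, while the demo loop and its loss values are illustrative only, not part of the commit.

# Sketch of the decayed average introduced by this patch (training/util.py).
class AverageMeter:
    def __init__(self, inv_gamma=1.0, power=2 / 3):
        self.inv_gamma = inv_gamma
        self.power = power
        self.reset()

    def reset(self):
        self.step = 0
        self.avg = 0

    def get_decay(self):
        # The first update uses decay 1 (avg becomes val); later updates
        # move avg by the shrinking fraction (step / inv_gamma) ** -power.
        if self.step <= 0:
            return 1
        return (self.step / self.inv_gamma) ** -self.power

    def update(self, val, n=1):
        for _ in range(n):
            self.step += n
            self.avg += (val - self.avg) * self.get_decay()


if __name__ == "__main__":
    meter = AverageMeter()
    for loss in [1.0, 0.8, 0.9, 0.5]:  # illustrative per-step losses
        meter.update(loss)
        print(f"step={meter.step} decay={meter.get_decay():.3f} avg={meter.avg:.3f}")

Because train_loop() and train() now accept these meters as arguments and seed best_acc / best_acc_val from their current averages, the same meter objects can presumably be passed across successive train() calls so the smoothed statistics carry over between training phases.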
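The training/strategy/lora.py change additionally stores any learned placeholder-token embeddings in the same safetensors checkpoint, under keys prefixed with "ti_$". A hedged sketch of how such a checkpoint could be split back into LoRA weights and embeddings when loading; load_file is the standard safetensors helper, while the file name and variable names here are illustrative:

from safetensors.torch import load_file

# Illustrative path; the patch writes f"{step}_{postfix}.safetensors".
state_dict = load_file("10000_end.safetensors")

# Embeddings written by the patched on_checkpoint() carry the "ti_$" prefix.
ti_embeds = {k[len("ti_$"):]: v for k, v in state_dict.items() if k.startswith("ti_$")}
lora_weights = {k: v for k, v in state_dict.items() if not k.startswith("ti_$")}

print("placeholder tokens:", sorted(ti_embeds))
print("LoRA tensors:", len(lora_weights))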