from pathlib import Path
import json
import math
from typing import Iterable, Any, Optional
from contextlib import contextmanager

import torch
from diffusers.training_utils import EMAModel as EMAModel_


def save_args(basepath: Path, args, extra: Optional[dict] = None):
    """Dump the attributes of ``args`` (plus ``extra``) to ``args.json``.

    Args:
        basepath: Directory in which ``args.json`` is written.
        args: Any object accepted by ``vars()`` (e.g. an argparse namespace).
        extra: Optional additional key/value pairs; on a key clash these
            override the values taken from ``args``.

    The file content is ``{"args": {...}}`` indented by 4 spaces.
    """
    # ``vars(args)`` is the object's own ``__dict__``; copy it so that merging
    # ``extra`` does not add attributes to the caller's ``args`` object.
    merged = dict(vars(args))
    if extra:
        merged.update(extra)
    with open(basepath.joinpath("args.json"), "w", encoding="utf-8") as f:
        json.dump({"args": merged}, f, indent=4)


class AverageMeter:
    """Running average whose update weight shrinks as the step count grows.

    Each sample moves ``avg`` towards the new value by a fraction
    ``(step / inv_gamma) ** -power``. ``min`` and ``max`` record the extremes
    of ``avg`` observed after each ``update`` call.
    """

    def __init__(self, inv_gamma=1.0, power=2 / 3):
        self.inv_gamma = inv_gamma
        self.power = power
        self.reset()

    def reset(self):
        """Return the meter to its initial, empty state."""
        self.step = 0
        self.min = math.inf
        # NOTE: ``max`` starts at 0.0 (not -inf), so it is only meaningful
        # when the tracked averages are non-negative.
        self.max = 0.0
        self.avg = 0.0

    def get_decay(self):
        """Return the weight given to the next sample at the current step.

        Returns 1 before any sample has been seen; otherwise
        ``(step / inv_gamma) ** -power``. The result is not clamped, so it
        exceeds 1 while ``step < inv_gamma``.
        """
        if self.step <= 0:
            return 1
        return (self.step / self.inv_gamma) ** -self.power

    def update(self, val, n=1):
        """Fold ``val`` into the average ``n`` times.

        Args:
            val: The new observation.
            n: How many times ``val`` is applied; each application advances
                ``step`` by one.
        """
        for _ in range(n):
            # One step per folded-in sample (previously ``+= n``, which
            # advanced the counter by n*n per call).
            self.step += 1
            self.avg += (val - self.avg) * self.get_decay()
        self.min = min(self.min, self.avg)
        self.max = max(self.max, self.avg)


class EMAModel(EMAModel_):
    """``diffusers`` EMA model extended with a temporary-apply context manager."""

    @contextmanager
    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        """Swap in ``self.copy_to(parameters)`` for the duration of a ``with`` block.

        The current values of ``parameters`` are snapshotted, ``copy_to`` is
        invoked on them (presumably writing the EMA weights; see the diffusers
        documentation), and on exit the snapshot is copied back, whether the
        block finished normally or raised.

        Args:
            parameters: The parameters to overwrite temporarily. The iterable
                is materialised once, so a generator is acceptable.
        """
        parameters = list(parameters)
        # Detach so the snapshot holds only values, not autograd history.
        backups = [p.detach().clone() for p in parameters]
        try:
            # Inside the ``try`` so a failure partway through ``copy_to``
            # still restores the original values.
            self.copy_to(parameters)
            yield
        finally:
            for saved, param in zip(backups, parameters):
                param.data.copy_(saved)