from pathlib import Path
import json
from typing import Iterable, Any, Optional
from contextlib import contextmanager

import torch
from diffusers.training_utils import EMAModel as EMAModel_


def save_args(basepath: Path, args, extra: Optional[dict] = None):
    """Write ``args`` (plus optional ``extra`` entries) to ``basepath/args.json``.

    The file contains a single object of the form ``{"args": {...}}``,
    pretty-printed with a 4-space indent. An existing file is overwritten.

    Args:
        basepath: Directory in which ``args.json`` is written.
        args: Any object accepted by ``vars()`` (presumably an
            ``argparse.Namespace`` — confirm against callers). Its attributes
            must be JSON-serializable, otherwise ``json.dump`` raises.
        extra: Additional key/value pairs merged over the attributes of
            ``args`` in the saved file. ``args`` itself is left unmodified.
    """
    # Copy: vars() returns the object's live __dict__, so updating it in place
    # would inject the ``extra`` keys into the caller's ``args`` as attributes.
    info = {"args": dict(vars(args))}
    if extra is not None:
        info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


class AverageMeter:
    """Track a running, count-weighted sum and mean of values."""

    avg: Any

    def __init__(self, name=None):
        # Optional label; stored on the instance but not used by this class.
        self.name = name
        self.reset()

    def reset(self):
        """Zero the accumulated sum, count and average."""
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Accumulate ``val`` with weight ``n`` and recompute ``avg``.

        ``avg`` becomes ``sum / count``. This raises ``ZeroDivisionError`` if
        the total count is still 0 afterwards (e.g. ``n=0`` on a fresh meter).
        """
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class EMAModel(EMAModel_):
    """diffusers ``EMAModel`` extended with a temporary weight-swap helper."""

    @contextmanager
    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
        """Temporarily overwrite ``parameters`` using the inherited ``copy_to``.

        Inside the ``with`` block, ``parameters`` hold whatever ``copy_to``
        writes (per the diffusers docs, the averaged weights). On exit their
        original values are copied back. Restoration also happens when the
        body raises, or when ``copy_to`` itself fails partway through.

        Args:
            parameters: The parameters to swap. A generator is accepted; it
                is materialized once because it is iterated several times.
        """
        parameters = list(parameters)
        # detach() so the snapshot is not recorded in the autograd graph.
        original_params = [p.detach().clone() for p in parameters]
        try:
            # Inside the try: a partial copy_to failure must still be undone.
            self.copy_to(parameters)
            yield
        finally:
            for o_param, param in zip(original_params, parameters):
                param.data.copy_(o_param.data)
            del original_params