From 6b58e9de249e872bd2d83e5916e6c633f52cfbb8 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 31 Dec 2022 12:58:54 +0100
Subject: Added multi-vector embeddings

---
 training/util.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 4 deletions(-)

(limited to 'training/util.py')

diff --git a/training/util.py b/training/util.py
index d0f7fcd..43a55e1 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 import json
+from typing import Iterable
 
 import torch
 from PIL import Image
@@ -39,8 +40,6 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        placeholder_token,
-        placeholder_token_id,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -48,8 +47,6 @@ class CheckpointerBase:
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
         self.sample_image_size = sample_image_size
         self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
@@ -117,3 +114,58 @@ class CheckpointerBase:
 
         del image_grid
         del generator
+
+
+class EMAModel:
+    """
+    Exponential Moving Average of models weights
+    """
+
+    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        self.decay = decay
+        self.optimization_step = 0
+
+    @torch.no_grad()
+    def step(self, parameters):
+        parameters = list(parameters)
+
+        self.optimization_step += 1
+
+        # Compute the decay factor for the exponential moving average.
+        value = (1 + self.optimization_step) / (10 + self.optimization_step)
+        one_minus_decay = 1 - min(self.decay, value)
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
-- 
cgit v1.2.3-54-g00ecf
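
Usage note (not part of the patch): the EMAModel added above follows the usual EMA pattern — shadow copies are initialized from the live weights, step() is called after every optimizer update (with a warm-up decay of (1 + step) / (10 + step), capped at `decay`), and copy_to() writes the averaged weights into a target set of parameters before sampling or checkpointing. Below is a minimal sketch of how it might be wired into a training loop; the module, optimizer, and loop here are hypothetical stand-ins, not code from this repository.

    import copy

    import torch

    from training.util import EMAModel

    # Hypothetical module and optimizer; any nn.Module works the same way.
    text_encoder = torch.nn.Linear(768, 768)
    optimizer = torch.optim.AdamW(text_encoder.parameters(), lr=1e-4)

    # Shadow parameters are cloned from the current weights at construction time.
    ema = EMAModel(text_encoder.parameters(), decay=0.9999)

    for _ in range(100):  # stand-in for the real dataloader loop
        loss = text_encoder(torch.randn(4, 768)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Update the running averages after every optimizer step.
        ema.step(text_encoder.parameters())

    # For sampling or checkpointing, copy the averaged weights into a clone so
    # the live training weights stay untouched.
    eval_encoder = copy.deepcopy(text_encoder)
    ema.copy_to(eval_encoder.parameters())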