From 1eef9a946161fd06b0e72ec804c68f4f0e74b380 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 9 Oct 2022 12:42:21 +0200
Subject: Update

---
 models/hypernetwork.py | 138 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 138 insertions(+)
 create mode 100644 models/hypernetwork.py

diff --git a/models/hypernetwork.py b/models/hypernetwork.py
new file mode 100644
index 0000000..fe8a312
--- /dev/null
+++ b/models/hypernetwork.py
@@ -0,0 +1,138 @@
+import math
+from typing import Dict, Optional, Iterable, List, Tuple, Any
+import copy
+import torch
+import numpy as np
+from torch import nn
+from functorch import make_functional, make_functional_with_buffers
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.modeling_utils import ModelMixin
+
+
+def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int):
+    weight_chunk_dim = math.ceil(num_target_parameters / num_embeddings)
+    if weight_chunk_dim != 0:
+        remainder = num_target_parameters % weight_chunk_dim
+        if remainder > 0:
+            diff = math.ceil(remainder / weight_chunk_dim)
+            num_embeddings += diff
+    return weight_chunk_dim
+
+
+def count_params(target: ModelMixin):
+    return sum([np.prod(p.size()) for p in target.parameters()])
+
+
+class FunctionalParamVectorWrapper(ModelMixin):
+    """
+    This wraps a module so that it takes params in the forward pass
+    """
+
+    def __init__(self, module: ModelMixin):
+        super().__init__()
+
+        self.custom_buffers = None
+        param_dict = dict(module.named_parameters())
+        self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict}
+
+        try:
+            _functional, self.named_params = make_functional(module)
+        except Exception:
+            _functional, self.named_params, buffers = make_functional_with_buffers(
+                module
+            )
+            self.custom_buffers = buffers
+        self.functional = [_functional] # remove params from being counted
+
+    def forward(self, param_vector: torch.Tensor, *args, **kwargs):
+        params = []
+        start = 0
+        for p in self.named_params:
+            end = start + np.prod(p.size())
+            params.append(param_vector[start:end].view(p.size()))
+            start = end
+        if self.custom_buffers is not None:
+            return self.functional[0](params, self.custom_buffers, *args, **kwargs)
+        return self.functional[0](params, *args, **kwargs)
+
+
+class Hypernetwork(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        target_network: ModelMixin,
+        num_target_parameters: Optional[int] = None,
+        embedding_dim: int = 100,
+        num_embeddings: int = 3,
+        weight_chunk_dim: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self._target = FunctionalParamVectorWrapper(target_network)
+
+        self.target_weight_shapes = self._target.target_weight_shapes
+
+        self.num_target_parameters = num_target_parameters
+
+        self.embedding_dim = embedding_dim
+        self.num_embeddings = num_embeddings
+        self.weight_chunk_dim = weight_chunk_dim
+
+        self.embedding_module = self.make_embedding_module()
+        self.weight_generator = self.make_weight_generator()
+
+    def make_embedding_module(self) -> nn.Module:
+        return nn.Embedding(self.num_embeddings, self.embedding_dim)
+
+    def make_weight_generator(self) -> nn.Module:
+        return nn.Linear(self.embedding_dim, self.weight_chunk_dim)
+
+    def generate_params(
+        self, inp: Iterable[Any] = []
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        embedding = self.embedding_module(
+            torch.arange(self.num_embeddings, device=self.device)
+        )
+        generated_params = self.weight_generator(embedding).view(-1)
+        return generated_params, {"embedding": embedding}
+
+    def forward(
+        self,
+        inp: Iterable[Any] = [],
+        *args,
+        **kwargs,
+    ):
+        generated_params, aux_output = self.generate_params(inp, *args, **kwargs)
+
+        assert generated_params.shape[-1] >= self.num_target_parameters
+
+        return self._target(generated_params, *inp)
+
+    @property
+    def device(self) -> torch.device:
+        return self._target.device
+
+    @classmethod
+    def from_target(
+        cls,
+        target_network: ModelMixin,
+        num_target_parameters: Optional[int] = None,
+        embedding_dim: int = 8,
+        num_embeddings: int = 3,
+        weight_chunk_dim: Optional[int] = None,
+        *args,
+        **kwargs,
+    ):
+        if num_target_parameters is None:
+            num_target_parameters = count_params(target_network)
+        if weight_chunk_dim is None:
+            weight_chunk_dim = get_weight_chunk_dims(num_target_parameters, num_embeddings)
+        return cls(
+            target_network=target_network,
+            num_target_parameters=num_target_parameters,
+            embedding_dim=embedding_dim,
+            num_embeddings=num_embeddings,
+            weight_chunk_dim=weight_chunk_dim,
+            *args,
+            **kwargs,
+        )
--
cgit v1.2.3-54-g00ecf
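
Usage sketch (not part of the patch itself): the snippet below walks through what the new Hypernetwork class does, using a small made-up target module. The target network, its layer sizes, and the models.hypernetwork import path are illustrative assumptions rather than anything taken from this commit, and the sketch presumes the functorch and diffusers versions implied by the file's imports.

# Illustrative sketch; target module, sizes, and import path are assumptions.
import torch
from torch import nn

from models.hypernetwork import Hypernetwork, count_params, get_weight_chunk_dims

# Hypothetical target network (annotated as ModelMixin in the file, but any
# plain nn.Module with parameters behaves the same for this sketch).
target = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))

n_params = count_params(target)             # 64 + 16 + 32 + 2 = 114
chunk = get_weight_chunk_dims(n_params, 3)  # ceil(114 / 3) = 38

# from_target performs the two calls above internally when the values are omitted.
hyper = Hypernetwork.from_target(target, embedding_dim=8, num_embeddings=3)

# The steps Hypernetwork.forward performs, spelled out:
# 3 learned embeddings -> Linear(8, 38) -> flat vector of 3 * 38 = 114 weights.
emb = hyper.embedding_module(torch.arange(3))
flat = hyper.weight_generator(emb).view(-1)  # shape: (114,)

# The functional wrapper slices the flat vector into the target's parameter
# shapes and runs the target module with those generated weights.
x = torch.randn(1, 4)
out = hyper._target(flat, x)                 # shape: (1, 2)

The chunked layout keeps the weight generator small: rather than one large linear layer emitting every target parameter at once, each of the num_embeddings embeddings produces a weight_chunk_dim-sized slice, and from_target rounds the chunk size up so the concatenated chunks are guaranteed to cover the target's parameter count (the assert in forward checks exactly this).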