import math
from typing import Any, Dict, Iterable, List, Optional, Tuple
import copy

import numpy as np
import torch
from torch import nn
from functorch import make_functional, make_functional_with_buffers

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin


def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int):
    # size of the weight chunk generated per embedding; ceil so that
    # num_embeddings * weight_chunk_dim covers all target parameters
    weight_chunk_dim = math.ceil(num_target_parameters / num_embeddings)
    if weight_chunk_dim != 0:
        remainder = num_target_parameters % weight_chunk_dim
        if remainder > 0:
            diff = math.ceil(remainder / weight_chunk_dim)
            num_embeddings += diff
    return weight_chunk_dim


def count_params(target: ModelMixin):
    # total number of scalar parameters in the target network
    return sum(np.prod(p.size()) for p in target.parameters())


class FunctionalParamVectorWrapper(ModelMixin):
    """
    Wraps a module so that its forward pass takes a flat parameter vector
    in place of the module's own parameters.
    """

    def __init__(self, module: ModelMixin):
        super().__init__()
        self.custom_buffers = None
        param_dict = dict(module.named_parameters())
        self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict}

        try:
            _functional, self.named_params = make_functional(module)
        except Exception:
            _functional, self.named_params, buffers = make_functional_with_buffers(module)
            self.custom_buffers = buffers
        # keep the functional module in a list so its parameters are not counted
        self.functional = [_functional]

    def forward(self, param_vector: torch.Tensor, *args, **kwargs):
        # slice the flat vector back into the original parameter shapes
        params = []
        start = 0
        for p in self.named_params:
            end = start + np.prod(p.size())
            params.append(param_vector[start:end].view(p.size()))
            start = end
        if self.custom_buffers is not None:
            return self.functional[0](params, self.custom_buffers, *args, **kwargs)
        return self.functional[0](params, *args, **kwargs)


class Hypernetwork(ModelMixin, ConfigMixin):
    """
    Generates the parameters of `target_network` from a set of learned embeddings.
    """

    @register_to_config
    def __init__(
        self,
        target_network: ModelMixin,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
    ):
        super().__init__()
        self._target = FunctionalParamVectorWrapper(target_network)
        self.target_weight_shapes = self._target.target_weight_shapes
        self.num_target_parameters = num_target_parameters
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.weight_chunk_dim = weight_chunk_dim

        self.embedding_module = self.make_embedding_module()
        self.weight_generator = self.make_weight_generator()

    def make_embedding_module(self) -> nn.Module:
        return nn.Embedding(self.num_embeddings, self.embedding_dim)

    def make_weight_generator(self) -> nn.Module:
        return nn.Linear(self.embedding_dim, self.weight_chunk_dim)

    def generate_params(self, inp: Iterable[Any] = []) -> Tuple[torch.Tensor, Dict[str, Any]]:
        # `inp` is unused by the base implementation; it is kept for subclasses
        # that condition the generated weights on an input
        embedding = self.embedding_module(torch.arange(self.num_embeddings, device=self.device))
        generated_params = self.weight_generator(embedding).view(-1)
        return generated_params, {"embedding": embedding}

    def forward(
        self,
        inp: Iterable[Any] = [],
        *args,
        **kwargs,
    ):
        generated_params, aux_output = self.generate_params(inp, *args, **kwargs)
        assert generated_params.shape[-1] >= self.num_target_parameters
        return self._target(generated_params, *inp)

    @property
    def device(self) -> torch.device:
        # the embedding indices built in `generate_params` must live on the same
        # device as the hypernetwork's own parameters; the wrapped target holds no
        # registered parameters, so its device cannot be queried directly
        return self.embedding_module.weight.device

    @classmethod
    def from_target(
        cls,
        target_network: ModelMixin,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 8,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
        *args,
        **kwargs,
    ):
        if num_target_parameters is None:
            num_target_parameters = count_params(target_network)
        if weight_chunk_dim is None:
            weight_chunk_dim = get_weight_chunk_dims(num_target_parameters, num_embeddings)
        return cls(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            weight_chunk_dim=weight_chunk_dim,
            *args,
            **kwargs,
        )
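

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module above).
# It assumes a toy single-layer target network; the class name `ToyTarget`
# and the tensor shapes are made up for demonstration. It also assumes a
# diffusers version exposing `diffusers.modeling_utils.ModelMixin` and a
# torch/functorch combination that still provides `make_functional`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ToyTarget(ModelMixin):
        # hypothetical target: a single linear layer with 4 * 2 + 2 = 10 parameters
        def __init__(self, in_features: int = 4, out_features: int = 2):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    target = ToyTarget()

    # build a hypernetwork whose generated vector covers all 10 target parameters
    # (weight_chunk_dim is inferred as ceil(10 / 3) = 4, so 3 * 4 = 12 >= 10)
    hypernetwork = Hypernetwork.from_target(
        target_network=target,
        embedding_dim=8,
        num_embeddings=3,
    )

    x = torch.randn(5, 4)
    # the target's forward pass runs with weights produced by the hypernetwork
    out = hypernetwork(inp=[x])
    print(out.shape)  # torch.Size([5, 2])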