From 0885c6eb569ce304fa7c8d86bf50f3deeaaa22e3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 10 Oct 2022 12:47:25 +0200 Subject: Remove unused code --- models/hypernetwork.py | 138 ------------------------------------------------- 1 file changed, 138 deletions(-) delete mode 100644 models/hypernetwork.py (limited to 'models') diff --git a/models/hypernetwork.py b/models/hypernetwork.py deleted file mode 100644 index fe8a312..0000000 --- a/models/hypernetwork.py +++ /dev/null @@ -1,138 +0,0 @@ -import math -from typing import Dict, Optional, Iterable, List, Tuple, Any -import copy -import torch -import numpy as np -from torch import nn -from functorch import make_functional, make_functional_with_buffers -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin - - -def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int): - weight_chunk_dim = math.ceil(num_target_parameters / num_embeddings) - if weight_chunk_dim != 0: - remainder = num_target_parameters % weight_chunk_dim - if remainder > 0: - diff = math.ceil(remainder / weight_chunk_dim) - num_embeddings += diff - return weight_chunk_dim - - -def count_params(target: ModelMixin): - return sum([np.prod(p.size()) for p in target.parameters()]) - - -class FunctionalParamVectorWrapper(ModelMixin): - """ - This wraps a module so that it takes params in the forward pass - """ - - def __init__(self, module: ModelMixin): - super().__init__() - - self.custom_buffers = None - param_dict = dict(module.named_parameters()) - self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict} - - try: - _functional, self.named_params = make_functional(module) - except Exception: - _functional, self.named_params, buffers = make_functional_with_buffers( - module - ) - self.custom_buffers = buffers - self.functional = [_functional] # remove params from being counted - - def forward(self, param_vector: torch.Tensor, *args, **kwargs): - params = [] - start = 0 - for p in self.named_params: - end = start + np.prod(p.size()) - params.append(param_vector[start:end].view(p.size())) - start = end - if self.custom_buffers is not None: - return self.functional[0](params, self.custom_buffers, *args, **kwargs) - return self.functional[0](params, *args, **kwargs) - - -class Hypernetwork(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - target_network: ModelMixin, - num_target_parameters: Optional[int] = None, - embedding_dim: int = 100, - num_embeddings: int = 3, - weight_chunk_dim: Optional[int] = None, - ): - super().__init__() - - self._target = FunctionalParamVectorWrapper(target_network) - - self.target_weight_shapes = self._target.target_weight_shapes - - self.num_target_parameters = num_target_parameters - - self.embedding_dim = embedding_dim - self.num_embeddings = num_embeddings - self.weight_chunk_dim = weight_chunk_dim - - self.embedding_module = self.make_embedding_module() - self.weight_generator = self.make_weight_generator() - - def make_embedding_module(self) -> nn.Module: - return nn.Embedding(self.num_embeddings, self.embedding_dim) - - def make_weight_generator(self) -> nn.Module: - return nn.Linear(self.embedding_dim, self.weight_chunk_dim) - - def generate_params( - self, inp: Iterable[Any] = [] - ) -> Tuple[torch.Tensor, Dict[str, Any]]: - embedding = self.embedding_module( - torch.arange(self.num_embeddings, device=self.device) - ) - generated_params = self.weight_generator(embedding).view(-1) - return generated_params, {"embedding": embedding} - - def forward( - self, - inp: Iterable[Any] = [], - *args, - **kwargs, - ): - generated_params, aux_output = self.generate_params(inp, *args, **kwargs) - - assert generated_params.shape[-1] >= self.num_target_parameters - - return self._target(generated_params, *inp) - - @property - def device(self) -> torch.device: - return self._target.device - - @classmethod - def from_target( - cls, - target_network: ModelMixin, - num_target_parameters: Optional[int] = None, - embedding_dim: int = 8, - num_embeddings: int = 3, - weight_chunk_dim: Optional[int] = None, - *args, - **kwargs, - ): - if num_target_parameters is None: - num_target_parameters = count_params(target_network) - if weight_chunk_dim is None: - weight_chunk_dim = get_weight_chunk_dims(num_target_parameters, num_embeddings) - return cls( - target_network=target_network, - num_target_parameters=num_target_parameters, - embedding_dim=embedding_dim, - num_embeddings=num_embeddings, - weight_chunk_dim=weight_chunk_dim, - *args, - **kwargs, - ) -- cgit v1.2.3-54-g00ecf