Diffstat (limited to 'models')
-rw-r--r--  models/hypernetwork.py  138
1 file changed, 0 insertions, 138 deletions
diff --git a/models/hypernetwork.py b/models/hypernetwork.py
deleted file mode 100644
index fe8a312..0000000
--- a/models/hypernetwork.py
+++ /dev/null
@@ -1,138 +0,0 @@
import math
from typing import Dict, Optional, Iterable, List, Tuple, Any
import copy
import torch
import numpy as np
from torch import nn
from functorch import make_functional, make_functional_with_buffers
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin


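# Size of the weight chunk generated per embedding: the target's parameter
# count divided by the number of embeddings, rounded up so that
# num_embeddings * weight_chunk_dim >= num_target_parameters. For example,
# 1000 target parameters across 3 embeddings gives chunks of 334 values
# (3 * 334 = 1002; the 2 surplus values are simply ignored when the flat
# vector is sliced back into parameter tensors). The num_embeddings
# adjustment below is local to this function; only the chunk size is returned.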
def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int):
    weight_chunk_dim = math.ceil(num_target_parameters / num_embeddings)
    if weight_chunk_dim != 0:
        remainder = num_target_parameters % weight_chunk_dim
        if remainder > 0:
            diff = math.ceil(remainder / weight_chunk_dim)
            num_embeddings += diff
    return weight_chunk_dim


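# Total number of scalar parameters in the target network.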
def count_params(target: ModelMixin):
    return sum([np.prod(p.size()) for p in target.parameters()])


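# Wraps a target module (via functorch) so that its weights become an explicit
# input: forward() takes a flat parameter vector, slices it back into tensors
# matching each named parameter's shape, and calls the functional form of the
# module with those tensors (plus any buffers captured at wrap time). Keeping
# the functional module inside a plain list prevents its parameters from being
# registered on, and counted as part of, this wrapper.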
class FunctionalParamVectorWrapper(ModelMixin):
    """
    This wraps a module so that it takes params in the forward pass
    """

    def __init__(self, module: ModelMixin):
        super().__init__()

        self.custom_buffers = None
        param_dict = dict(module.named_parameters())
        self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict}

        try:
            _functional, self.named_params = make_functional(module)
        except Exception:
            _functional, self.named_params, buffers = make_functional_with_buffers(
                module
            )
            self.custom_buffers = buffers
        self.functional = [_functional]  # remove params from being counted

    def forward(self, param_vector: torch.Tensor, *args, **kwargs):
        params = []
        start = 0
        for p in self.named_params:
            end = start + np.prod(p.size())
            params.append(param_vector[start:end].view(p.size()))
            start = end
        if self.custom_buffers is not None:
            return self.functional[0](params, self.custom_buffers, *args, **kwargs)
        return self.functional[0](params, *args, **kwargs)


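# A chunked hypernetwork: an embedding table (num_embeddings x embedding_dim)
# feeds a linear weight generator that emits one chunk of weight_chunk_dim
# values per embedding. Flattened together, the chunks form a parameter vector
# of at least num_target_parameters entries, which FunctionalParamVectorWrapper
# uses as the weights of the wrapped target network.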
class Hypernetwork(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        target_network: ModelMixin,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
    ):
        super().__init__()

        self._target = FunctionalParamVectorWrapper(target_network)

        self.target_weight_shapes = self._target.target_weight_shapes

        self.num_target_parameters = num_target_parameters

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.weight_chunk_dim = weight_chunk_dim

        self.embedding_module = self.make_embedding_module()
        self.weight_generator = self.make_weight_generator()

    def make_embedding_module(self) -> nn.Module:
        return nn.Embedding(self.num_embeddings, self.embedding_dim)

    def make_weight_generator(self) -> nn.Module:
        return nn.Linear(self.embedding_dim, self.weight_chunk_dim)

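    # Produces the flat parameter vector: embed indices 0..num_embeddings-1,
    # project each embedding to a chunk of weight_chunk_dim values, and flatten
    # to a vector of num_embeddings * weight_chunk_dim entries. The embeddings
    # are returned as auxiliary output; `inp` is accepted for interface
    # compatibility but unused here.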
    def generate_params(
        self, inp: Iterable[Any] = [], *args, **kwargs
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        embedding = self.embedding_module(
            torch.arange(self.num_embeddings, device=self.device)
        )
        generated_params = self.weight_generator(embedding).view(-1)
        return generated_params, {"embedding": embedding}

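    # Full pipeline: generate the parameter vector, check it covers the target's
    # parameter count, then run the wrapped target with the generated weights
    # and the positional inputs in `inp`.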
    def forward(
        self,
        inp: Iterable[Any] = [],
        *args,
        **kwargs,
    ):
        generated_params, aux_output = self.generate_params(inp, *args, **kwargs)

        assert generated_params.shape[-1] >= self.num_target_parameters

        return self._target(generated_params, *inp)

    @property
    def device(self) -> torch.device:
        # the wrapped target registers no parameters of its own, so take the
        # device from the hypernetwork's embedding weights
        return self.embedding_module.weight.device

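    # Convenience constructor: counts the target's parameters and derives
    # weight_chunk_dim from the requested number of embeddings when these are
    # not given explicitly.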
    @classmethod
    def from_target(
        cls,
        target_network: ModelMixin,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 8,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
        *args,
        **kwargs,
    ):
        if num_target_parameters is None:
            num_target_parameters = count_params(target_network)
        if weight_chunk_dim is None:
            weight_chunk_dim = get_weight_chunk_dims(num_target_parameters, num_embeddings)
        return cls(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            weight_chunk_dim=weight_chunk_dim,
            *args,
            **kwargs,
        )
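

# Example usage (a minimal sketch; `TinyModel` and the tensors below are
# illustrative and not part of this module, and assume the functorch-era
# torch/diffusers versions this file was written against):
#
#     class TinyModel(ModelMixin):
#         def __init__(self):
#             super().__init__()
#             self.net = nn.Linear(4, 2)
#
#         def forward(self, x):
#             return self.net(x)
#
#     hypernet = Hypernetwork.from_target(TinyModel())
#     # 10 target parameters over 3 embeddings -> chunks of 4 values each
#     # (3 * 4 = 12 >= 10); the generator is Embedding(3, 8) -> Linear(8, 4).
#     out = hypernet(inp=[torch.randn(1, 4)])  # -> tensor of shape (1, 2)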