diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/hypernetwork.py | 138 |
1 files changed, 0 insertions, 138 deletions
diff --git a/models/hypernetwork.py b/models/hypernetwork.py deleted file mode 100644 index fe8a312..0000000 --- a/models/hypernetwork.py +++ /dev/null | |||
@@ -1,138 +0,0 @@ | |||
1 | import math | ||
2 | from typing import Dict, Optional, Iterable, List, Tuple, Any | ||
3 | import copy | ||
4 | import torch | ||
5 | import numpy as np | ||
6 | from torch import nn | ||
7 | from functorch import make_functional, make_functional_with_buffers | ||
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
9 | from diffusers.modeling_utils import ModelMixin | ||
10 | |||
11 | |||
12 | def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int): | ||
13 | weight_chunk_dim = math.ceil(num_target_parameters / num_embeddings) | ||
14 | if weight_chunk_dim != 0: | ||
15 | remainder = num_target_parameters % weight_chunk_dim | ||
16 | if remainder > 0: | ||
17 | diff = math.ceil(remainder / weight_chunk_dim) | ||
18 | num_embeddings += diff | ||
19 | return weight_chunk_dim | ||
20 | |||
21 | |||
22 | def count_params(target: ModelMixin): | ||
23 | return sum([np.prod(p.size()) for p in target.parameters()]) | ||
24 | |||
25 | |||
26 | class FunctionalParamVectorWrapper(ModelMixin): | ||
27 | """ | ||
28 | This wraps a module so that it takes params in the forward pass | ||
29 | """ | ||
30 | |||
31 | def __init__(self, module: ModelMixin): | ||
32 | super().__init__() | ||
33 | |||
34 | self.custom_buffers = None | ||
35 | param_dict = dict(module.named_parameters()) | ||
36 | self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict} | ||
37 | |||
38 | try: | ||
39 | _functional, self.named_params = make_functional(module) | ||
40 | except Exception: | ||
41 | _functional, self.named_params, buffers = make_functional_with_buffers( | ||
42 | module | ||
43 | ) | ||
44 | self.custom_buffers = buffers | ||
45 | self.functional = [_functional] # remove params from being counted | ||
46 | |||
47 | def forward(self, param_vector: torch.Tensor, *args, **kwargs): | ||
48 | params = [] | ||
49 | start = 0 | ||
50 | for p in self.named_params: | ||
51 | end = start + np.prod(p.size()) | ||
52 | params.append(param_vector[start:end].view(p.size())) | ||
53 | start = end | ||
54 | if self.custom_buffers is not None: | ||
55 | return self.functional[0](params, self.custom_buffers, *args, **kwargs) | ||
56 | return self.functional[0](params, *args, **kwargs) | ||
57 | |||
58 | |||
59 | class Hypernetwork(ModelMixin, ConfigMixin): | ||
60 | @register_to_config | ||
61 | def __init__( | ||
62 | self, | ||
63 | target_network: ModelMixin, | ||
64 | num_target_parameters: Optional[int] = None, | ||
65 | embedding_dim: int = 100, | ||
66 | num_embeddings: int = 3, | ||
67 | weight_chunk_dim: Optional[int] = None, | ||
68 | ): | ||
69 | super().__init__() | ||
70 | |||
71 | self._target = FunctionalParamVectorWrapper(target_network) | ||
72 | |||
73 | self.target_weight_shapes = self._target.target_weight_shapes | ||
74 | |||
75 | self.num_target_parameters = num_target_parameters | ||
76 | |||
77 | self.embedding_dim = embedding_dim | ||
78 | self.num_embeddings = num_embeddings | ||
79 | self.weight_chunk_dim = weight_chunk_dim | ||
80 | |||
81 | self.embedding_module = self.make_embedding_module() | ||
82 | self.weight_generator = self.make_weight_generator() | ||
83 | |||
84 | def make_embedding_module(self) -> nn.Module: | ||
85 | return nn.Embedding(self.num_embeddings, self.embedding_dim) | ||
86 | |||
87 | def make_weight_generator(self) -> nn.Module: | ||
88 | return nn.Linear(self.embedding_dim, self.weight_chunk_dim) | ||
89 | |||
90 | def generate_params( | ||
91 | self, inp: Iterable[Any] = [] | ||
92 | ) -> Tuple[torch.Tensor, Dict[str, Any]]: | ||
93 | embedding = self.embedding_module( | ||
94 | torch.arange(self.num_embeddings, device=self.device) | ||
95 | ) | ||
96 | generated_params = self.weight_generator(embedding).view(-1) | ||
97 | return generated_params, {"embedding": embedding} | ||
98 | |||
99 | def forward( | ||
100 | self, | ||
101 | inp: Iterable[Any] = [], | ||
102 | *args, | ||
103 | **kwargs, | ||
104 | ): | ||
105 | generated_params, aux_output = self.generate_params(inp, *args, **kwargs) | ||
106 | |||
107 | assert generated_params.shape[-1] >= self.num_target_parameters | ||
108 | |||
109 | return self._target(generated_params, *inp) | ||
110 | |||
111 | @property | ||
112 | def device(self) -> torch.device: | ||
113 | return self._target.device | ||
114 | |||
115 | @classmethod | ||
116 | def from_target( | ||
117 | cls, | ||
118 | target_network: ModelMixin, | ||
119 | num_target_parameters: Optional[int] = None, | ||
120 | embedding_dim: int = 8, | ||
121 | num_embeddings: int = 3, | ||
122 | weight_chunk_dim: Optional[int] = None, | ||
123 | *args, | ||
124 | **kwargs, | ||
125 | ): | ||
126 | if num_target_parameters is None: | ||
127 | num_target_parameters = count_params(target_network) | ||
128 | if weight_chunk_dim is None: | ||
129 | weight_chunk_dim = get_weight_chunk_dims(num_target_parameters, num_embeddings) | ||
130 | return cls( | ||
131 | target_network=target_network, | ||
132 | num_target_parameters=num_target_parameters, | ||
133 | embedding_dim=embedding_dim, | ||
134 | num_embeddings=num_embeddings, | ||
135 | weight_chunk_dim=weight_chunk_dim, | ||
136 | *args, | ||
137 | **kwargs, | ||
138 | ) | ||