From 1eef9a946161fd06b0e72ec804c68f4f0e74b380 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 9 Oct 2022 12:42:21 +0200 Subject: Update --- dreambooth.py | 7 +-- infer.py | 4 +- models/hypernetwork.py | 138 +++++++++++++++++++++++++++++++++++++++++++++++++ textual_inversion.py | 10 ++-- 4 files changed, 146 insertions(+), 13 deletions(-) create mode 100644 models/hypernetwork.py diff --git a/dreambooth.py b/dreambooth.py index 7b61c45..48fc7f2 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -15,14 +15,14 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel -from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from schedulers.scheduling_euler_a import EulerAScheduler +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule logger = get_logger(__name__) @@ -334,7 +334,6 @@ class Checkpointer: beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True ), ) - pipeline.enable_attention_slicing() pipeline.save_pretrained(self.output_dir.joinpath("model")) del unwrapped @@ -359,7 +358,6 @@ class Checkpointer: tokenizer=self.tokenizer, scheduler=scheduler, ).to(self.accelerator.device) - pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) train_data = self.datamodule.train_dataloader() @@ -561,7 +559,6 @@ def main(): tokenizer=tokenizer, scheduler=scheduler, ).to(accelerator.device) - pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) with torch.inference_mode(): diff --git a/infer.py b/infer.py index a542534..70851fd 100644 --- a/infer.py +++ b/infer.py @@ -11,8 +11,9 @@ from PIL import Image from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion + from schedulers.scheduling_euler_a import EulerAScheduler +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion torch.backends.cuda.matmul.allow_tf32 = True @@ -235,7 +236,6 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype): tokenizer=tokenizer, scheduler=scheduler, ) - # pipeline.enable_attention_slicing() pipeline.to("cuda") print("Pipeline loaded.") diff --git a/models/hypernetwork.py b/models/hypernetwork.py new file mode 100644 index 0000000..fe8a312 --- /dev/null +++ b/models/hypernetwork.py @@ -0,0 +1,138 @@ +import math +from typing import Dict, Optional, Iterable, List, Tuple, Any +import copy +import torch +import numpy as np +from torch import nn +from functorch import make_functional, make_functional_with_buffers +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin + + +def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int): + weight_chunk_dim = math.ceil(num_target_parameters / num_embeddings) + if weight_chunk_dim != 0: + remainder = num_target_parameters % weight_chunk_dim + if remainder > 0: + diff = math.ceil(remainder / weight_chunk_dim) + num_embeddings += diff + return weight_chunk_dim + + +def count_params(target: ModelMixin): + return sum([np.prod(p.size()) for p in target.parameters()]) + + +class FunctionalParamVectorWrapper(ModelMixin): + """ + This wraps a module so that it takes params in the forward pass + """ + + def __init__(self, module: ModelMixin): + super().__init__() + + self.custom_buffers = None + param_dict = dict(module.named_parameters()) + self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict} + + try: + _functional, self.named_params = make_functional(module) + except Exception: + _functional, self.named_params, buffers = make_functional_with_buffers( + module + ) + self.custom_buffers = buffers + self.functional = [_functional] # remove params from being counted + + def forward(self, param_vector: torch.Tensor, *args, **kwargs): + params = [] + start = 0 + for p in self.named_params: + end = start + np.prod(p.size()) + params.append(param_vector[start:end].view(p.size())) + start = end + if self.custom_buffers is not None: + return self.functional[0](params, self.custom_buffers, *args, **kwargs) + return self.functional[0](params, *args, **kwargs) + + +class Hypernetwork(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + target_network: ModelMixin, + num_target_parameters: Optional[int] = None, + embedding_dim: int = 100, + num_embeddings: int = 3, + weight_chunk_dim: Optional[int] = None, + ): + super().__init__() + + self._target = FunctionalParamVectorWrapper(target_network) + + self.target_weight_shapes = self._target.target_weight_shapes + + self.num_target_parameters = num_target_parameters + + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.weight_chunk_dim = weight_chunk_dim + + self.embedding_module = self.make_embedding_module() + self.weight_generator = self.make_weight_generator() + + def make_embedding_module(self) -> nn.Module: + return nn.Embedding(self.num_embeddings, self.embedding_dim) + + def make_weight_generator(self) -> nn.Module: + return nn.Linear(self.embedding_dim, self.weight_chunk_dim) + + def generate_params( + self, inp: Iterable[Any] = [] + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + embedding = self.embedding_module( + torch.arange(self.num_embeddings, device=self.device) + ) + generated_params = self.weight_generator(embedding).view(-1) + return generated_params, {"embedding": embedding} + + def forward( + self, + inp: Iterable[Any] = [], + *args, + **kwargs, + ): + generated_params, aux_output = self.generate_params(inp, *args, **kwargs) + + assert generated_params.shape[-1] >= self.num_target_parameters + + return self._target(generated_params, *inp) + + @property + def device(self) -> torch.device: + return self._target.device + + @classmethod + def from_target( + cls, + target_network: ModelMixin, + num_target_parameters: Optional[int] = None, + embedding_dim: int = 8, + num_embeddings: int = 3, + weight_chunk_dim: Optional[int] = None, + *args, + **kwargs, + ): + if num_target_parameters is None: + num_target_parameters = count_params(target_network) + if weight_chunk_dim is None: + weight_chunk_dim = get_weight_chunk_dims(num_target_parameters, num_embeddings) + return cls( + target_network=target_network, + num_target_parameters=num_target_parameters, + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + weight_chunk_dim=weight_chunk_dim, + *args, + **kwargs, + ) diff --git a/textual_inversion.py b/textual_inversion.py index 09871d4..e641cab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -16,14 +16,14 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel -from schedulers.scheduling_euler_a import EulerAScheduler from diffusers.optimization import get_scheduler from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion +from schedulers.scheduling_euler_a import EulerAScheduler +from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule logger = get_logger(__name__) @@ -388,7 +388,6 @@ class Checkpointer: tokenizer=self.tokenizer, scheduler=scheduler, ).to(self.accelerator.device) - pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) train_data = self.datamodule.train_dataloader() @@ -518,8 +517,8 @@ def main(): if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - slice_size = unet.config.attention_head_dim // 2 - unet.set_attention_slice(slice_size) + # slice_size = unet.config.attention_head_dim // 2 + # unet.set_attention_slice(slice_size) # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) @@ -639,7 +638,6 @@ def main(): tokenizer=tokenizer, scheduler=scheduler, ).to(accelerator.device) - pipeline.enable_attention_slicing() pipeline.set_progress_bar_config(dynamic_ncols=True) with torch.inference_mode(): -- cgit v1.2.3-70-g09d2