diff options
| -rw-r--r-- | dreambooth.py | 7 | ||||
| -rw-r--r-- | infer.py | 4 | ||||
| -rw-r--r-- | models/hypernetwork.py | 138 | ||||
| -rw-r--r-- | textual_inversion.py | 10 |
4 files changed, 146 insertions, 13 deletions
diff --git a/dreambooth.py b/dreambooth.py index 7b61c45..48fc7f2 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -15,14 +15,14 @@ from accelerate import Accelerator | |||
| 15 | from accelerate.logging import get_logger | 15 | from accelerate.logging import get_logger |
| 16 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
| 17 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel | 17 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel |
| 18 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 19 | from diffusers.optimization import get_scheduler | 18 | from diffusers.optimization import get_scheduler |
| 20 | from PIL import Image | 19 | from PIL import Image |
| 21 | from tqdm.auto import tqdm | 20 | from tqdm.auto import tqdm |
| 22 | from transformers import CLIPTextModel, CLIPTokenizer | 21 | from transformers import CLIPTextModel, CLIPTokenizer |
| 23 | from slugify import slugify | 22 | from slugify import slugify |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 25 | 23 | ||
| 24 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 26 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
| 27 | 27 | ||
| 28 | logger = get_logger(__name__) | 28 | logger = get_logger(__name__) |
| @@ -334,7 +334,6 @@ class Checkpointer: | |||
| 334 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 334 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True |
| 335 | ), | 335 | ), |
| 336 | ) | 336 | ) |
| 337 | pipeline.enable_attention_slicing() | ||
| 338 | pipeline.save_pretrained(self.output_dir.joinpath("model")) | 337 | pipeline.save_pretrained(self.output_dir.joinpath("model")) |
| 339 | 338 | ||
| 340 | del unwrapped | 339 | del unwrapped |
| @@ -359,7 +358,6 @@ class Checkpointer: | |||
| 359 | tokenizer=self.tokenizer, | 358 | tokenizer=self.tokenizer, |
| 360 | scheduler=scheduler, | 359 | scheduler=scheduler, |
| 361 | ).to(self.accelerator.device) | 360 | ).to(self.accelerator.device) |
| 362 | pipeline.enable_attention_slicing() | ||
| 363 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 361 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 364 | 362 | ||
| 365 | train_data = self.datamodule.train_dataloader() | 363 | train_data = self.datamodule.train_dataloader() |
| @@ -561,7 +559,6 @@ def main(): | |||
| 561 | tokenizer=tokenizer, | 559 | tokenizer=tokenizer, |
| 562 | scheduler=scheduler, | 560 | scheduler=scheduler, |
| 563 | ).to(accelerator.device) | 561 | ).to(accelerator.device) |
| 564 | pipeline.enable_attention_slicing() | ||
| 565 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 562 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 566 | 563 | ||
| 567 | with torch.inference_mode(): | 564 | with torch.inference_mode(): |
| @@ -11,8 +11,9 @@ from PIL import Image | |||
| 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
| 12 | from transformers import CLIPTextModel, CLIPTokenizer | 12 | from transformers import CLIPTextModel, CLIPTokenizer |
| 13 | from slugify import slugify | 13 | from slugify import slugify |
| 14 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 14 | |
| 15 | from schedulers.scheduling_euler_a import EulerAScheduler | 15 | from schedulers.scheduling_euler_a import EulerAScheduler |
| 16 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 16 | 17 | ||
| 17 | 18 | ||
| 18 | torch.backends.cuda.matmul.allow_tf32 = True | 19 | torch.backends.cuda.matmul.allow_tf32 = True |
| @@ -235,7 +236,6 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype): | |||
| 235 | tokenizer=tokenizer, | 236 | tokenizer=tokenizer, |
| 236 | scheduler=scheduler, | 237 | scheduler=scheduler, |
| 237 | ) | 238 | ) |
| 238 | # pipeline.enable_attention_slicing() | ||
| 239 | pipeline.to("cuda") | 239 | pipeline.to("cuda") |
| 240 | 240 | ||
| 241 | print("Pipeline loaded.") | 241 | print("Pipeline loaded.") |
diff --git a/models/hypernetwork.py b/models/hypernetwork.py new file mode 100644 index 0000000..fe8a312 --- /dev/null +++ b/models/hypernetwork.py | |||
| @@ -0,0 +1,138 @@ | |||
| 1 | import math | ||
| 2 | from typing import Dict, Optional, Iterable, List, Tuple, Any | ||
| 3 | import copy | ||
| 4 | import torch | ||
| 5 | import numpy as np | ||
| 6 | from torch import nn | ||
| 7 | from functorch import make_functional, make_functional_with_buffers | ||
| 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config | ||
| 9 | from diffusers.modeling_utils import ModelMixin | ||
| 10 | |||
| 11 | |||
| 12 | def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int): | ||
| 13 | weight_chunk_dim = math.ceil(num_target_parameters / num_embeddings) | ||
| 14 | if weight_chunk_dim != 0: | ||
| 15 | remainder = num_target_parameters % weight_chunk_dim | ||
| 16 | if remainder > 0: | ||
| 17 | diff = math.ceil(remainder / weight_chunk_dim) | ||
| 18 | num_embeddings += diff | ||
| 19 | return weight_chunk_dim | ||
| 20 | |||
| 21 | |||
| 22 | def count_params(target: ModelMixin): | ||
| 23 | return sum([np.prod(p.size()) for p in target.parameters()]) | ||
| 24 | |||
| 25 | |||
| 26 | class FunctionalParamVectorWrapper(ModelMixin): | ||
| 27 | """ | ||
| 28 | This wraps a module so that it takes params in the forward pass | ||
| 29 | """ | ||
| 30 | |||
| 31 | def __init__(self, module: ModelMixin): | ||
| 32 | super().__init__() | ||
| 33 | |||
| 34 | self.custom_buffers = None | ||
| 35 | param_dict = dict(module.named_parameters()) | ||
| 36 | self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict} | ||
| 37 | |||
| 38 | try: | ||
| 39 | _functional, self.named_params = make_functional(module) | ||
| 40 | except Exception: | ||
| 41 | _functional, self.named_params, buffers = make_functional_with_buffers( | ||
| 42 | module | ||
| 43 | ) | ||
| 44 | self.custom_buffers = buffers | ||
| 45 | self.functional = [_functional] # remove params from being counted | ||
| 46 | |||
| 47 | def forward(self, param_vector: torch.Tensor, *args, **kwargs): | ||
| 48 | params = [] | ||
| 49 | start = 0 | ||
| 50 | for p in self.named_params: | ||
| 51 | end = start + np.prod(p.size()) | ||
| 52 | params.append(param_vector[start:end].view(p.size())) | ||
| 53 | start = end | ||
| 54 | if self.custom_buffers is not None: | ||
| 55 | return self.functional[0](params, self.custom_buffers, *args, **kwargs) | ||
| 56 | return self.functional[0](params, *args, **kwargs) | ||
| 57 | |||
| 58 | |||
| 59 | class Hypernetwork(ModelMixin, ConfigMixin): | ||
| 60 | @register_to_config | ||
| 61 | def __init__( | ||
| 62 | self, | ||
| 63 | target_network: ModelMixin, | ||
| 64 | num_target_parameters: Optional[int] = None, | ||
| 65 | embedding_dim: int = 100, | ||
| 66 | num_embeddings: int = 3, | ||
| 67 | weight_chunk_dim: Optional[int] = None, | ||
| 68 | ): | ||
| 69 | super().__init__() | ||
| 70 | |||
| 71 | self._target = FunctionalParamVectorWrapper(target_network) | ||
| 72 | |||
| 73 | self.target_weight_shapes = self._target.target_weight_shapes | ||
| 74 | |||
| 75 | self.num_target_parameters = num_target_parameters | ||
| 76 | |||
| 77 | self.embedding_dim = embedding_dim | ||
| 78 | self.num_embeddings = num_embeddings | ||
| 79 | self.weight_chunk_dim = weight_chunk_dim | ||
| 80 | |||
| 81 | self.embedding_module = self.make_embedding_module() | ||
| 82 | self.weight_generator = self.make_weight_generator() | ||
| 83 | |||
| 84 | def make_embedding_module(self) -> nn.Module: | ||
| 85 | return nn.Embedding(self.num_embeddings, self.embedding_dim) | ||
| 86 | |||
| 87 | def make_weight_generator(self) -> nn.Module: | ||
| 88 | return nn.Linear(self.embedding_dim, self.weight_chunk_dim) | ||
| 89 | |||
| 90 | def generate_params( | ||
| 91 | self, inp: Iterable[Any] = [] | ||
| 92 | ) -> Tuple[torch.Tensor, Dict[str, Any]]: | ||
| 93 | embedding = self.embedding_module( | ||
| 94 | torch.arange(self.num_embeddings, device=self.device) | ||
| 95 | ) | ||
| 96 | generated_params = self.weight_generator(embedding).view(-1) | ||
| 97 | return generated_params, {"embedding": embedding} | ||
| 98 | |||
| 99 | def forward( | ||
| 100 | self, | ||
| 101 | inp: Iterable[Any] = [], | ||
| 102 | *args, | ||
| 103 | **kwargs, | ||
| 104 | ): | ||
| 105 | generated_params, aux_output = self.generate_params(inp, *args, **kwargs) | ||
| 106 | |||
| 107 | assert generated_params.shape[-1] >= self.num_target_parameters | ||
| 108 | |||
| 109 | return self._target(generated_params, *inp) | ||
| 110 | |||
| 111 | @property | ||
| 112 | def device(self) -> torch.device: | ||
| 113 | return self._target.device | ||
| 114 | |||
| 115 | @classmethod | ||
| 116 | def from_target( | ||
| 117 | cls, | ||
| 118 | target_network: ModelMixin, | ||
| 119 | num_target_parameters: Optional[int] = None, | ||
| 120 | embedding_dim: int = 8, | ||
| 121 | num_embeddings: int = 3, | ||
| 122 | weight_chunk_dim: Optional[int] = None, | ||
| 123 | *args, | ||
| 124 | **kwargs, | ||
| 125 | ): | ||
| 126 | if num_target_parameters is None: | ||
| 127 | num_target_parameters = count_params(target_network) | ||
| 128 | if weight_chunk_dim is None: | ||
| 129 | weight_chunk_dim = get_weight_chunk_dims(num_target_parameters, num_embeddings) | ||
| 130 | return cls( | ||
| 131 | target_network=target_network, | ||
| 132 | num_target_parameters=num_target_parameters, | ||
| 133 | embedding_dim=embedding_dim, | ||
| 134 | num_embeddings=num_embeddings, | ||
| 135 | weight_chunk_dim=weight_chunk_dim, | ||
| 136 | *args, | ||
| 137 | **kwargs, | ||
| 138 | ) | ||
diff --git a/textual_inversion.py b/textual_inversion.py index 09871d4..e641cab 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -16,14 +16,14 @@ from accelerate import Accelerator | |||
| 16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel |
| 19 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 20 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler |
| 21 | from PIL import Image | 20 | from PIL import Image |
| 22 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
| 23 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
| 24 | from slugify import slugify | 23 | from slugify import slugify |
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 26 | 24 | ||
| 25 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 27 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
| 28 | 28 | ||
| 29 | logger = get_logger(__name__) | 29 | logger = get_logger(__name__) |
| @@ -388,7 +388,6 @@ class Checkpointer: | |||
| 388 | tokenizer=self.tokenizer, | 388 | tokenizer=self.tokenizer, |
| 389 | scheduler=scheduler, | 389 | scheduler=scheduler, |
| 390 | ).to(self.accelerator.device) | 390 | ).to(self.accelerator.device) |
| 391 | pipeline.enable_attention_slicing() | ||
| 392 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 391 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 393 | 392 | ||
| 394 | train_data = self.datamodule.train_dataloader() | 393 | train_data = self.datamodule.train_dataloader() |
| @@ -518,8 +517,8 @@ def main(): | |||
| 518 | if args.gradient_checkpointing: | 517 | if args.gradient_checkpointing: |
| 519 | unet.enable_gradient_checkpointing() | 518 | unet.enable_gradient_checkpointing() |
| 520 | 519 | ||
| 521 | slice_size = unet.config.attention_head_dim // 2 | 520 | # slice_size = unet.config.attention_head_dim // 2 |
| 522 | unet.set_attention_slice(slice_size) | 521 | # unet.set_attention_slice(slice_size) |
| 523 | 522 | ||
| 524 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | 523 | # Resize the token embeddings as we are adding new special tokens to the tokenizer |
| 525 | text_encoder.resize_token_embeddings(len(tokenizer)) | 524 | text_encoder.resize_token_embeddings(len(tokenizer)) |
| @@ -639,7 +638,6 @@ def main(): | |||
| 639 | tokenizer=tokenizer, | 638 | tokenizer=tokenizer, |
| 640 | scheduler=scheduler, | 639 | scheduler=scheduler, |
| 641 | ).to(accelerator.device) | 640 | ).to(accelerator.device) |
| 642 | pipeline.enable_attention_slicing() | ||
| 643 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 641 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 644 | 642 | ||
| 645 | with torch.inference_mode(): | 643 | with torch.inference_mode(): |
