summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--dreambooth.py7
-rw-r--r--infer.py4
-rw-r--r--models/hypernetwork.py138
-rw-r--r--textual_inversion.py10
4 files changed, 146 insertions, 13 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 7b61c45..48fc7f2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,14 +15,14 @@ from accelerate import Accelerator
15from accelerate.logging import get_logger 15from accelerate.logging import get_logger
16from accelerate.utils import LoggerType, set_seed 16from accelerate.utils import LoggerType, set_seed
17from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel 17from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
18from schedulers.scheduling_euler_a import EulerAScheduler
19from diffusers.optimization import get_scheduler 18from diffusers.optimization import get_scheduler
20from PIL import Image 19from PIL import Image
21from tqdm.auto import tqdm 20from tqdm.auto import tqdm
22from transformers import CLIPTextModel, CLIPTokenizer 21from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 22from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25 23
24from schedulers.scheduling_euler_a import EulerAScheduler
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from data.csv import CSVDataModule 26from data.csv import CSVDataModule
27 27
28logger = get_logger(__name__) 28logger = get_logger(__name__)
@@ -334,7 +334,6 @@ class Checkpointer:
334 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True 334 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
335 ), 335 ),
336 ) 336 )
337 pipeline.enable_attention_slicing()
338 pipeline.save_pretrained(self.output_dir.joinpath("model")) 337 pipeline.save_pretrained(self.output_dir.joinpath("model"))
339 338
340 del unwrapped 339 del unwrapped
@@ -359,7 +358,6 @@ class Checkpointer:
359 tokenizer=self.tokenizer, 358 tokenizer=self.tokenizer,
360 scheduler=scheduler, 359 scheduler=scheduler,
361 ).to(self.accelerator.device) 360 ).to(self.accelerator.device)
362 pipeline.enable_attention_slicing()
363 pipeline.set_progress_bar_config(dynamic_ncols=True) 361 pipeline.set_progress_bar_config(dynamic_ncols=True)
364 362
365 train_data = self.datamodule.train_dataloader() 363 train_data = self.datamodule.train_dataloader()
@@ -561,7 +559,6 @@ def main():
561 tokenizer=tokenizer, 559 tokenizer=tokenizer,
562 scheduler=scheduler, 560 scheduler=scheduler,
563 ).to(accelerator.device) 561 ).to(accelerator.device)
564 pipeline.enable_attention_slicing()
565 pipeline.set_progress_bar_config(dynamic_ncols=True) 562 pipeline.set_progress_bar_config(dynamic_ncols=True)
566 563
567 with torch.inference_mode(): 564 with torch.inference_mode():
diff --git a/infer.py b/infer.py
index a542534..70851fd 100644
--- a/infer.py
+++ b/infer.py
@@ -11,8 +11,9 @@ from PIL import Image
11from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler 11from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
12from transformers import CLIPTextModel, CLIPTokenizer 12from transformers import CLIPTextModel, CLIPTokenizer
13from slugify import slugify 13from slugify import slugify
14from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 14
15from schedulers.scheduling_euler_a import EulerAScheduler 15from schedulers.scheduling_euler_a import EulerAScheduler
16from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
16 17
17 18
18torch.backends.cuda.matmul.allow_tf32 = True 19torch.backends.cuda.matmul.allow_tf32 = True
@@ -235,7 +236,6 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
235 tokenizer=tokenizer, 236 tokenizer=tokenizer,
236 scheduler=scheduler, 237 scheduler=scheduler,
237 ) 238 )
238 # pipeline.enable_attention_slicing()
239 pipeline.to("cuda") 239 pipeline.to("cuda")
240 240
241 print("Pipeline loaded.") 241 print("Pipeline loaded.")
diff --git a/models/hypernetwork.py b/models/hypernetwork.py
new file mode 100644
index 0000000..fe8a312
--- /dev/null
+++ b/models/hypernetwork.py
@@ -0,0 +1,138 @@
1import math
2from typing import Dict, Optional, Iterable, List, Tuple, Any
3import copy
4import torch
5import numpy as np
6from torch import nn
7from functorch import make_functional, make_functional_with_buffers
8from diffusers.configuration_utils import ConfigMixin, register_to_config
9from diffusers.modeling_utils import ModelMixin
10
11
def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int) -> int:
    """Return the number of target-parameter scalars each embedding must generate.

    Uses ceiling division so that ``num_embeddings`` chunks of this size cover
    at least ``num_target_parameters`` scalars (the consumer slices off the
    excess).

    NOTE(review): the original version also incremented a local
    ``num_embeddings`` when the chunk size did not divide evenly, but that
    value was never returned or used — dead code, removed here. The ceiling
    division alone already guarantees
    ``num_embeddings * chunk_dim >= num_target_parameters``.
    """
    return math.ceil(num_target_parameters / num_embeddings)
20
21
def count_params(target: ModelMixin):
    """Return the total number of parameter scalars in *target*.

    Sums the element count of every tensor yielded by
    ``target.parameters()`` (``np.prod`` of its shape).
    """
    total = 0
    for param in target.parameters():
        total += np.prod(param.size())
    return total
24
25
class FunctionalParamVectorWrapper(ModelMixin):
    """
    This wraps a module so that it takes params in the forward pass

    The wrapped module is converted to a functional form via functorch, so
    calling this wrapper with a flat parameter vector runs the module with
    those externally supplied weights instead of its own.
    """

    def __init__(self, module: ModelMixin):
        super().__init__()

        # Buffers captured from the module, if it has any (set in the except
        # branch below); None means the plain functional path is used.
        self.custom_buffers = None
        param_dict = dict(module.named_parameters())
        # name -> torch.Size, preserved so callers can reconstruct per-weight
        # shapes from a flat vector.
        self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict}

        try:
            # make_functional raises for modules with buffers; fall back to
            # the buffers-aware variant in that case.
            _functional, self.named_params = make_functional(module)
        except Exception:
            _functional, self.named_params, buffers = make_functional_with_buffers(
                module
            )
            self.custom_buffers = buffers
        # Stored inside a plain list (not as a submodule/attribute tracked by
        # nn.Module) so the functional's params are not registered here.
        self.functional = [_functional]  # remove params from being counted

    def forward(self, param_vector: torch.Tensor, *args, **kwargs):
        """Run the wrapped module using weights sliced from ``param_vector``.

        ``param_vector`` is consumed front-to-back in the iteration order of
        ``self.named_params``; each slice is reshaped to the corresponding
        parameter's size. Extra ``*args``/``**kwargs`` are forwarded to the
        functional module call.
        """
        params = []
        start = 0
        for p in self.named_params:
            # Flat-slice exactly numel(p) entries, then view as p's shape.
            end = start + np.prod(p.size())
            params.append(param_vector[start:end].view(p.size()))
            start = end
        if self.custom_buffers is not None:
            return self.functional[0](params, self.custom_buffers, *args, **kwargs)
        return self.functional[0](params, *args, **kwargs)
57
58
class Hypernetwork(ModelMixin, ConfigMixin):
    """A hypernetwork that generates the flat parameter vector of a wrapped
    target network from a small learned embedding table.

    ``num_embeddings`` chunk embeddings are each projected to
    ``weight_chunk_dim`` scalars; the concatenated vector is fed to a
    :class:`FunctionalParamVectorWrapper` around ``target_network``.
    """

    @register_to_config
    def __init__(
        self,
        target_network: ModelMixin,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 100,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
    ):
        super().__init__()

        # Functional wrapper so the target can be called with externally
        # generated parameters instead of its own registered weights.
        self._target = FunctionalParamVectorWrapper(target_network)

        # name -> torch.Size of every target parameter (exposed for callers).
        self.target_weight_shapes = self._target.target_weight_shapes

        # Total scalar count the generated vector must cover (checked in forward).
        self.num_target_parameters = num_target_parameters

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.weight_chunk_dim = weight_chunk_dim

        self.embedding_module = self.make_embedding_module()
        self.weight_generator = self.make_weight_generator()

    def make_embedding_module(self) -> nn.Module:
        # One learnable embedding per weight chunk.
        return nn.Embedding(self.num_embeddings, self.embedding_dim)

    def make_weight_generator(self) -> nn.Module:
        # Projects each chunk embedding to weight_chunk_dim target scalars.
        return nn.Linear(self.embedding_dim, self.weight_chunk_dim)

    def generate_params(
        self, inp: Iterable[Any] = (), *args, **kwargs
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Produce the flat parameter vector for the target network.

        Returns ``(generated_params, aux)`` where ``aux`` carries the raw
        chunk embeddings. ``inp`` and any extra arguments are accepted (and
        ignored here) so subclasses can condition generation on the input;
        accepting them also keeps :meth:`forward`'s pass-through call valid —
        the original signature made ``forward(inp, extra)`` raise TypeError.
        """
        embedding = self.embedding_module(
            torch.arange(self.num_embeddings, device=self.device)
        )
        # (num_embeddings, weight_chunk_dim) -> flat vector.
        generated_params = self.weight_generator(embedding).view(-1)
        return generated_params, {"embedding": embedding}

    def forward(
        self,
        inp: Iterable[Any] = (),  # immutable default; original used a mutable []
        *args,
        **kwargs,
    ):
        """Generate target parameters, then run the target network on ``*inp``."""
        generated_params, aux_output = self.generate_params(inp, *args, **kwargs)

        # The wrapper consumes the vector front-to-back, so it must hold at
        # least num_target_parameters scalars (aux_output is currently unused).
        assert generated_params.shape[-1] >= self.num_target_parameters

        return self._target(generated_params, *inp)

    @property
    def device(self) -> torch.device:
        # Delegate to the wrapped target; NOTE(review): presumably the
        # ModelMixin base tracks the device — confirm against diffusers.
        return self._target.device

    @classmethod
    def from_target(
        cls,
        target_network: ModelMixin,
        num_target_parameters: Optional[int] = None,
        embedding_dim: int = 8,
        num_embeddings: int = 3,
        weight_chunk_dim: Optional[int] = None,
        *args,
        **kwargs,
    ):
        """Build a Hypernetwork sized automatically for ``target_network``.

        Fills in the parameter count and chunk size when not given.
        """
        if num_target_parameters is None:
            num_target_parameters = count_params(target_network)
        if weight_chunk_dim is None:
            weight_chunk_dim = get_weight_chunk_dims(num_target_parameters, num_embeddings)
        # NOTE(review): any non-empty *args here collides with the keyword
        # arguments below (__init__ takes no further positionals) and raises
        # TypeError; kept only for signature compatibility.
        return cls(
            target_network=target_network,
            num_target_parameters=num_target_parameters,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            weight_chunk_dim=weight_chunk_dim,
            *args,
            **kwargs,
        )
diff --git a/textual_inversion.py b/textual_inversion.py
index 09871d4..e641cab 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -16,14 +16,14 @@ from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 18from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
19from schedulers.scheduling_euler_a import EulerAScheduler
20from diffusers.optimization import get_scheduler 19from diffusers.optimization import get_scheduler
21from PIL import Image 20from PIL import Image
22from tqdm.auto import tqdm 21from tqdm.auto import tqdm
23from transformers import CLIPTextModel, CLIPTokenizer 22from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify 23from slugify import slugify
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26 24
25from schedulers.scheduling_euler_a import EulerAScheduler
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule 27from data.csv import CSVDataModule
28 28
29logger = get_logger(__name__) 29logger = get_logger(__name__)
@@ -388,7 +388,6 @@ class Checkpointer:
388 tokenizer=self.tokenizer, 388 tokenizer=self.tokenizer,
389 scheduler=scheduler, 389 scheduler=scheduler,
390 ).to(self.accelerator.device) 390 ).to(self.accelerator.device)
391 pipeline.enable_attention_slicing()
392 pipeline.set_progress_bar_config(dynamic_ncols=True) 391 pipeline.set_progress_bar_config(dynamic_ncols=True)
393 392
394 train_data = self.datamodule.train_dataloader() 393 train_data = self.datamodule.train_dataloader()
@@ -518,8 +517,8 @@ def main():
518 if args.gradient_checkpointing: 517 if args.gradient_checkpointing:
519 unet.enable_gradient_checkpointing() 518 unet.enable_gradient_checkpointing()
520 519
521 slice_size = unet.config.attention_head_dim // 2 520 # slice_size = unet.config.attention_head_dim // 2
522 unet.set_attention_slice(slice_size) 521 # unet.set_attention_slice(slice_size)
523 522
524 # Resize the token embeddings as we are adding new special tokens to the tokenizer 523 # Resize the token embeddings as we are adding new special tokens to the tokenizer
525 text_encoder.resize_token_embeddings(len(tokenizer)) 524 text_encoder.resize_token_embeddings(len(tokenizer))
@@ -639,7 +638,6 @@ def main():
639 tokenizer=tokenizer, 638 tokenizer=tokenizer,
640 scheduler=scheduler, 639 scheduler=scheduler,
641 ).to(accelerator.device) 640 ).to(accelerator.device)
642 pipeline.enable_attention_slicing()
643 pipeline.set_progress_bar_config(dynamic_ncols=True) 641 pipeline.set_progress_bar_config(dynamic_ncols=True)
644 642
645 with torch.inference_mode(): 643 with torch.inference_mode():