author    Volpeon <git@volpeon.ink>  2023-04-16 09:44:12 +0200
committer Volpeon <git@volpeon.ink>  2023-04-16 09:44:12 +0200
commit    1a0161f345191d78a19eec829f9d73b2c2c72f94 (patch)
tree      6d7bcc67672ebf26454b3254b4bd9d5ec7e64a16
parent    Fix (diff)
Update
-rw-r--r--  data/csv.py                14
-rw-r--r--  data/keywords.py           13
-rw-r--r--  models/clip/embeddings.py   3
-rw-r--r--  models/lora.py             59
-rw-r--r--  train_lora.py               6
-rw-r--r--  train_ti.py                 6

6 files changed, 69 insertions(+), 32 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 3af9925..c5e7aef 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,12 +1,13 @@
 import math
-import torch
 import json
 from functools import partial
 from pathlib import Path
 from typing import NamedTuple, Optional, Union, Callable
 
 from PIL import Image
+import numpy as np
 
+import torch
 from torch.utils.data import IterableDataset, DataLoader, random_split
 from torchvision import transforms
 from transformers import CLIPTokenizer
@@ -141,8 +142,8 @@ class VlpnDataItem(NamedTuple):
     nprompt: str
     collection: list[str]
 
-    def full_prompt(self, dropout: float = 0, shuffle: bool = False):
-        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle)
+    def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None):
+        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator)
 
 
 def keyword_filter(
@@ -193,6 +194,7 @@ class VlpnDataModule():
         train_set_pad: Optional[int] = None,
         valid_set_pad: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
+        npgenerator: Optional[np.random.Generator] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
     ):
@@ -228,6 +230,7 @@ class VlpnDataModule():
         self.batch_size = batch_size
         self.dtype = dtype
         self.generator = generator
+        self.npgenerator = npgenerator or np.random.default_rng()
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         tpl_image = template["image"] if "image" in template else "{}"
@@ -297,6 +300,7 @@ class VlpnDataModule():
 
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
+        self.npgenerator.shuffle(items)
 
         num_images = len(items)
 
@@ -370,6 +374,7 @@ class VlpnDataset(IterableDataset):
         interpolation: str = "bicubic",
         color_jitter: bool = True,
         generator: Optional[torch.Generator] = None,
+        npgenerator: Optional[np.random.Generator] = None,
     ):
         self.items = items
         self.batch_size = batch_size
@@ -383,6 +388,7 @@ class VlpnDataset(IterableDataset):
         self.interpolation = interpolations[interpolation]
         self.color_jitter = color_jitter
         self.generator = generator
+        self.npgenerator = npgenerator
 
         self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
             [item.instance_image_path for item in self.items],
@@ -477,7 +483,7 @@ class VlpnDataset(IterableDataset):
         example["prompt_ids"] = self.get_input_ids(item.full_prompt())
         example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-        example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True))
+        example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator))
         example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
         example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
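The net effect in data/csv.py: item order and keyword shuffling are now driven by an explicit numpy generator instead of global numpy state. A minimal sketch of the fallback used in __init__ above (file names and item list are made up):

    import numpy as np

    # Mirrors `self.npgenerator = npgenerator or np.random.default_rng()`:
    # when no generator is supplied, an unseeded one is created.
    npgenerator = None
    npgenerator = npgenerator or np.random.default_rng()

    items = ["a.png", "b.png", "c.png"]  # placeholder item list
    npgenerator.shuffle(items)           # in-place, as in self.npgenerator.shuffle(items)
    print(items)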
diff --git a/data/keywords.py b/data/keywords.py
index 629006d..8632d67 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,14 +1,23 @@
+from typing import Optional
+
 import numpy as np
 
 
-def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str:
+def keywords_to_str(
+    keywords: list[str],
+    undroppable_keywords: list[str] = [],
+    dropout: float = 0,
+    shuffle: bool = False,
+    npgenerator: Optional[np.random.Generator] = None
+) -> str:
     if dropout != 0:
         keywords = [keyword for keyword in keywords if np.random.random() > dropout]
     else:
         keywords = keywords.copy()
     keywords += undroppable_keywords
     if shuffle:
-        np.random.shuffle(keywords)
+        npgenerator = npgenerator or np.random.default_rng()
+        npgenerator.shuffle(keywords)
     return ", ".join(keywords)
 
 
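With the generator threaded through, shuffled prompts become reproducible. A small usage sketch (the keyword values are illustrative); note the dropout branch above still draws from the global np.random, so only shuffling is governed by npgenerator:

    import numpy as np
    from data.keywords import keywords_to_str

    rng = np.random.default_rng(42)
    # Same seed => same shuffle order on every run.
    prompt = keywords_to_str(
        ["portrait", "oil painting"],   # droppable keywords
        ["a photo of X"],               # undroppable, appended before shuffling
        shuffle=True,
        npgenerator=rng,
    )
    print(prompt)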
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 840f8ae..4444cf9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,8 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.eval()
-        self.token_embedding.merged = False
+        self.token_embedding.persist()
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
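persist() now defers to the LoRA embedding's own persist() (added in models/lora.py below), which folds the learned delta into the base weight instead of merely flipping eval/merged flags. Hypothetical call sequence, assuming an already-trained `embeddings` instance:

    # After this call (per models/lora.py below) the LoRA delta is baked
    # into token_embedding.weight, trainable_ids are reset to -1, and
    # lora_A is emptied, so forward() reduces to a plain embedding lookup.
    embeddings.persist()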
diff --git a/models/lora.py b/models/lora.py
index 89c4b2e..b7fa58f 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -46,8 +46,8 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         self.trainable_ids -= 1
 
         if r > 0:
-            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
-            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+            self.lora_A = nn.ParameterList()
+            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
             self.scaling = self.lora_alpha / self.r
             self.weight.requires_grad = False
@@ -83,49 +83,64 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n1 = self.lora_A.shape[1]
+        n1 = len(self.lora_A)
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
+        for _ in new_ids:
+            self.lora_A.append(self.weight.new_zeros(self.r))
+
+    def persist(self):
+        if self.r > 0:
+            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            if weights is not None:
+                self.weight[mask].data += weights
+            self.trainable_ids[:] = -1
+            self.lora_A = nn.ParameterList()
+
+    def get_weights(self, input_ids: torch.Tensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        mask = ~(trainable_ids == -1)
+        trainable_ids = trainable_ids[mask]
+
+        elems = [self.lora_A[id] for id in trainable_ids]
+
+        if len(elems) == 0:
+            return None, mask
+
+        weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
 
-        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
-        self.lora_A = lora_A
+        return weights, mask
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
         if hasattr(self, 'lora_A'):
-            nn.init.zeros_(self.lora_A)
-            nn.init.normal_(self.lora_B)
+            self.lora_A = nn.ParameterList()
+            nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
         if self.merge_weights and self.merged:
             if self.r > 0:
-                mask = ~(self.trainable_ids == -1)
-                trainable_ids = self.trainable_ids[mask]
-                self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling
+                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                if weights is not None:
+                    self.weight[mask].data -= weights
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
         if self.merge_weights and not self.merged:
             if self.r > 0:
-                mask = ~(self.trainable_ids == -1)
-                trainable_ids = self.trainable_ids[mask]
-                self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling
+                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                if weights is not None:
+                    self.weight[mask].data += weights
             self.merged = True
 
     def forward(self, input_ids: torch.Tensor):
         result = nn.Embedding.forward(self, input_ids)
 
         if self.r > 0 and not self.merged:
-            trainable_ids = self.trainable_ids[input_ids]
-            mask = ~(trainable_ids == -1)
-            trainable_ids = trainable_ids[mask]
-
-            after_A = F.embedding(
-                trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
-                self.norm_type, self.scale_grad_by_freq, self.sparse
-            )
-            result[mask] += (after_A @ self.lora_B.T) * self.scaling
+            weights, mask = self.get_weights(input_ids)
+            if weights is not None:
+                result[mask] += weights
 
         return result
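The refactor replaces the dense (r x n) lora_A matrix with one rank-r vector per trainable token, projected up by a shared nn.Linear. A self-contained sketch of that decomposition (the dimensions are arbitrary examples, not values from this repo):

    import torch
    import torch.nn as nn

    r, embedding_dim = 4, 768
    lora_A = nn.ParameterList()                 # one rank-r vector per trainable token
    lora_B = nn.Linear(r, embedding_dim, bias=False)
    nn.init.zeros_(lora_B.weight)               # as in reset_parameters: delta starts at 0

    for _ in range(2):                          # register two trainable token ids
        lora_A.append(nn.Parameter(torch.zeros(r)))

    # get_weights stacks the gathered vectors and projects them up:
    delta = lora_B(torch.stack([p for p in lora_A]))
    assert delta.shape == (2, embedding_dim)    # one delta row per trainable token

Growing the ParameterList appends vectors for new token ids without reallocating the existing ones, which the old code did by rebuilding the full (r, n2) matrix on every call.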
diff --git a/train_lora.py b/train_lora.py
index 91bda5c..d5dde02 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,9 +13,11 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
-from slugify import slugify
 import transformers
 
+import numpy as np
+from slugify import slugify
+
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
@@ -873,6 +875,7 @@ def main():
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+    data_npgenerator = np.random.default_rng(args.seed)
 
     create_datamodule = partial(
         VlpnDataModule,
@@ -893,6 +896,7 @@ def main():
         valid_set_pad=args.valid_set_pad,
         dtype=weight_dtype,
         generator=data_generator,
+        npgenerator=data_npgenerator,
     )
 
     create_lr_scheduler = partial(
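Both generators are seeded from the same args.seed, so torch-side sampling and numpy-side shuffling are jointly reproducible. A minimal sketch of the seeding pattern above (the literal seed stands in for args.seed):

    import numpy as np
    import torch

    seed = 42  # stands in for args.seed
    data_generator = torch.Generator(device="cpu").manual_seed(seed)
    data_npgenerator = np.random.default_rng(seed)
    # Handed to VlpnDataModule as generator=/npgenerator=, so dataset splits
    # (torch) and item/keyword shuffling (numpy) replay identically across runs.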
diff --git a/train_ti.py b/train_ti.py
index 6c57f4b..7f5fb49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,10 +12,12 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from slugify import slugify
 from timm.models import create_model
 import transformers
 
+import numpy as np
+from slugify import slugify
+
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
@@ -852,6 +854,7 @@ def main():
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+    data_npgenerator = np.random.default_rng(args.seed)
 
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -894,6 +897,7 @@ def main():
             filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype,
             generator=data_generator,
+            npgenerator=data_npgenerator,
         )
         datamodule.setup()
 
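In train_ti.py the generator is created once in main() and captured by run(), so successive run() invocations consume one continuing random stream rather than restarting at the seed. An illustration of that behavior:

    import numpy as np

    rng = np.random.default_rng(0)       # created once, like data_npgenerator
    first = rng.permutation(4)
    second = rng.permutation(4)          # continues the stream; generally != first
    # A fresh np.random.default_rng(0) would reproduce `first` exactly.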