 data/csv.py               | 14 +-
 data/keywords.py          | 13 +-
 models/clip/embeddings.py |  3 +-
 models/lora.py            | 59 +-
 train_lora.py             |  6 +-
 train_ti.py               |  6 +-
 6 files changed, 69 insertions(+), 32 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 3af9925..c5e7aef 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,12 +1,13 @@
 import math
-import torch
 import json
 from functools import partial
 from pathlib import Path
 from typing import NamedTuple, Optional, Union, Callable
 
 from PIL import Image
+import numpy as np
 
+import torch
 from torch.utils.data import IterableDataset, DataLoader, random_split
 from torchvision import transforms
 from transformers import CLIPTokenizer
@@ -141,8 +142,8 @@ class VlpnDataItem(NamedTuple):
     nprompt: str
     collection: list[str]
 
-    def full_prompt(self, dropout: float = 0, shuffle: bool = False):
-        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle)
+    def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None):
+        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator)
 
 
 def keyword_filter(
@@ -193,6 +194,7 @@ class VlpnDataModule():
         train_set_pad: Optional[int] = None,
         valid_set_pad: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
+        npgenerator: Optional[np.random.Generator] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
     ):
@@ -228,6 +230,7 @@ class VlpnDataModule():
         self.batch_size = batch_size
         self.dtype = dtype
         self.generator = generator
+        self.npgenerator = npgenerator or np.random.default_rng()
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         tpl_image = template["image"] if "image" in template else "{}"
@@ -297,6 +300,7 @@ class VlpnDataModule():
 
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
+        self.npgenerator.shuffle(items)
 
         num_images = len(items)
 
@@ -370,6 +374,7 @@ class VlpnDataset(IterableDataset):
         interpolation: str = "bicubic",
         color_jitter: bool = True,
         generator: Optional[torch.Generator] = None,
+        npgenerator: Optional[np.random.Generator] = None,
     ):
         self.items = items
         self.batch_size = batch_size
@@ -383,6 +388,7 @@ class VlpnDataset(IterableDataset):
         self.interpolation = interpolations[interpolation]
         self.color_jitter = color_jitter
         self.generator = generator
+        self.npgenerator = npgenerator
 
         self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
             [item.instance_image_path for item in self.items],
@@ -477,7 +483,7 @@ class VlpnDataset(IterableDataset):
             example["prompt_ids"] = self.get_input_ids(item.full_prompt())
             example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-            example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True))
+            example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator))
             example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
             example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
diff --git a/data/keywords.py b/data/keywords.py
index 629006d..8632d67 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,14 +1,23 @@
+from typing import Optional
+
 import numpy as np
 
 
-def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str:
+def keywords_to_str(
+    keywords: list[str],
+    undroppable_keywords: list[str] = [],
+    dropout: float = 0,
+    shuffle: bool = False,
+    npgenerator: Optional[np.random.Generator] = None
+) -> str:
     if dropout != 0:
         keywords = [keyword for keyword in keywords if np.random.random() > dropout]
     else:
         keywords = keywords.copy()
     keywords += undroppable_keywords
     if shuffle:
-        np.random.shuffle(keywords)
+        npgenerator = npgenerator or np.random.default_rng()
+        npgenerator.shuffle(keywords)
     return ", ".join(keywords)
 
 
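Note: keywords_to_str takes the generator as an optional argument and falls back
to a fresh np.random.default_rng() when none is passed, so existing callers keep
working. Only the shuffle branch goes through the generator; the dropout branch
still draws from the global np.random state. A usage sketch with made-up keyword
values:

    import numpy as np
    from data.keywords import keywords_to_str

    rng = np.random.default_rng(0)
    # Shuffle order is now driven by the passed-in generator:
    print(keywords_to_str(["portrait", "outdoors"], ["my-token"],
                          shuffle=True, npgenerator=rng))
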
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 840f8ae..4444cf9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,8 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.eval()
-        self.token_embedding.merged = False
+        self.token_embedding.persist()
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
diff --git a/models/lora.py b/models/lora.py
index 89c4b2e..b7fa58f 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -46,8 +46,8 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         self.trainable_ids -= 1
 
         if r > 0:
-            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
-            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+            self.lora_A = nn.ParameterList()
+            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
             self.scaling = self.lora_alpha / self.r
             self.weight.requires_grad = False
 
@@ -83,49 +83,64 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n1 = self.lora_A.shape[1]
+        n1 = len(self.lora_A)
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
+        for _ in new_ids:
+            self.lora_A.append(self.weight.new_zeros(self.r))
 
-        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
-        self.lora_A = lora_A
+    def persist(self):
+        if self.r > 0:
+            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            if weights is not None:
+                self.weight[mask].data += weights
+            self.trainable_ids[:] = -1
+            self.lora_A = nn.ParameterList()
+
+    def get_weights(self, input_ids: torch.Tensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        mask = ~(trainable_ids == -1)
+        trainable_ids = trainable_ids[mask]
+
+        elems = [self.lora_A[id] for id in trainable_ids]
+
+        if len(elems) == 0:
+            return None, mask
+
+        weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
+
+        return weights, mask
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
         if hasattr(self, 'lora_A'):
-            nn.init.zeros_(self.lora_A)
-            nn.init.normal_(self.lora_B)
+            self.lora_A = nn.ParameterList()
+            nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
         if self.merge_weights and self.merged:
             if self.r > 0:
-                mask = ~(self.trainable_ids == -1)
-                trainable_ids = self.trainable_ids[mask]
-                self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling
+                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                if weights is not None:
+                    self.weight[mask].data -= weights
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
         if self.merge_weights and not self.merged:
             if self.r > 0:
-                mask = ~(self.trainable_ids == -1)
-                trainable_ids = self.trainable_ids[mask]
-                self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling
+                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                if weights is not None:
+                    self.weight[mask].data += weights
             self.merged = True
 
     def forward(self, input_ids: torch.Tensor):
         result = nn.Embedding.forward(self, input_ids)
 
         if self.r > 0 and not self.merged:
-            trainable_ids = self.trainable_ids[input_ids]
-            mask = ~(trainable_ids == -1)
-            trainable_ids = trainable_ids[mask]
-
-            after_A = F.embedding(
-                trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
-                self.norm_type, self.scale_grad_by_freq, self.sparse
-            )
-            result[mask] += (after_A @ self.lora_B.T) * self.scaling
+            weights, mask = self.get_weights(input_ids)
+            if weights is not None:
+                result[mask] += weights
 
         return result
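Note: the LoraEmbedding refactor stores one rank-r vector per trainable token in
an nn.ParameterList and projects it up with a shared nn.Linear, so new tokens can
be appended without reallocating a packed lora_A matrix. A simplified,
self-contained sketch of the get_weights() delta computation (hypothetical
dimensions, dropout omitted):

    import torch
    import torch.nn as nn

    r, embedding_dim, scaling = 4, 768, 1.0
    # One rank-r vector per trainable token:
    lora_A = nn.ParameterList([nn.Parameter(torch.zeros(r)) for _ in range(3)])
    lora_B = nn.Linear(r, embedding_dim, bias=False)

    trainable_ids = torch.tensor([0, -1, 2])  # -1 marks tokens with no LoRA slot
    mask = trainable_ids != -1
    elems = [lora_A[int(i)] for i in trainable_ids[mask]]
    weights = lora_B(torch.stack(elems)) * scaling  # (num_trainable, embedding_dim)
    # forward() adds these rows onto the base embedding output selected by `mask`.
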
diff --git a/train_lora.py b/train_lora.py
index 91bda5c..d5dde02 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,9 +13,11 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
-from slugify import slugify
 import transformers
 
+import numpy as np
+from slugify import slugify
+
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
@@ -873,6 +875,7 @@ def main():
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+    data_npgenerator = np.random.default_rng(args.seed)
 
     create_datamodule = partial(
         VlpnDataModule,
@@ -893,6 +896,7 @@ def main():
         valid_set_pad=args.valid_set_pad,
         dtype=weight_dtype,
         generator=data_generator,
+        npgenerator=data_npgenerator,
     )
 
     create_lr_scheduler = partial(
diff --git a/train_ti.py b/train_ti.py
index 6c57f4b..7f5fb49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,10 +12,12 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from slugify import slugify
 from timm.models import create_model
 import transformers
 
+import numpy as np
+from slugify import slugify
+
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
@@ -852,6 +854,7 @@ def main():
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+    data_npgenerator = np.random.default_rng(args.seed)
 
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -894,6 +897,7 @@ def main():
             filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype,
             generator=data_generator,
+            npgenerator=data_npgenerator,
         )
         datamodule.setup()
 
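Note: both training scripts derive the numpy generator from the same args.seed as
the torch generator, so a single seed now pins torch-side sampling and numpy-side
shuffling together. Roughly, with a stand-in seed value:

    import numpy as np
    import torch

    seed = 42  # stand-in for args.seed
    data_generator = torch.Generator(device="cpu").manual_seed(seed)
    data_npgenerator = np.random.default_rng(seed)
    # Both are passed to VlpnDataModule(..., generator=data_generator,
    #                                   npgenerator=data_npgenerator)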
