From 1a0161f345191d78a19eec829f9d73b2c2c72f94 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 16 Apr 2023 09:44:12 +0200
Subject: Update

---
 data/csv.py               | 14 +++++++----
 data/keywords.py          | 13 +++++++++--
 models/clip/embeddings.py |  3 +--
 models/lora.py            | 59 +++++++++++++++++++++++++++++------------------
 train_lora.py             |  6 ++++-
 train_ti.py               |  6 ++++-
 6 files changed, 69 insertions(+), 32 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 3af9925..c5e7aef 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,12 +1,13 @@
 import math
-import torch
 import json
 from functools import partial
 from pathlib import Path
 from typing import NamedTuple, Optional, Union, Callable
 
 from PIL import Image
 
+import numpy as np
+import torch
 from torch.utils.data import IterableDataset, DataLoader, random_split
 from torchvision import transforms
 from transformers import CLIPTokenizer
@@ -141,8 +142,8 @@ class VlpnDataItem(NamedTuple):
     nprompt: str
     collection: list[str]
 
-    def full_prompt(self, dropout: float = 0, shuffle: bool = False):
-        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle)
+    def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None):
+        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator)
 
 
 def keyword_filter(
@@ -193,6 +194,7 @@ class VlpnDataModule():
         train_set_pad: Optional[int] = None,
         valid_set_pad: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
+        npgenerator: Optional[np.random.Generator] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
     ):
@@ -228,6 +230,7 @@ class VlpnDataModule():
         self.batch_size = batch_size
         self.dtype = dtype
         self.generator = generator
+        self.npgenerator = npgenerator or np.random.default_rng()
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         tpl_image = template["image"] if "image" in template else "{}"
@@ -297,6 +300,7 @@ class VlpnDataModule():
 
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
+        self.npgenerator.shuffle(items)
 
         num_images = len(items)
 
@@ -370,6 +374,7 @@ class VlpnDataset(IterableDataset):
         interpolation: str = "bicubic",
         color_jitter: bool = True,
         generator: Optional[torch.Generator] = None,
+        npgenerator: Optional[np.random.Generator] = None,
     ):
         self.items = items
         self.batch_size = batch_size
@@ -383,6 +388,7 @@ class VlpnDataset(IterableDataset):
         self.interpolation = interpolations[interpolation]
         self.color_jitter = color_jitter
         self.generator = generator
+        self.npgenerator = npgenerator
 
         self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
             [item.instance_image_path for item in self.items],
@@ -477,7 +483,7 @@ class VlpnDataset(IterableDataset):
                 example["prompt_ids"] = self.get_input_ids(item.full_prompt())
                 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-                example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True))
+                example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator))
                 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
                 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
 
diff --git a/data/keywords.py b/data/keywords.py
index 629006d..8632d67 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -1,14 +1,23 @@
+from typing import Optional
+
 import numpy as np
 
-def keywords_to_str(keywords: list[str], undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False) -> str:
+def keywords_to_str(
+    keywords: list[str],
+    undroppable_keywords: list[str] = [],
+    dropout: float = 0,
+    shuffle: bool = False,
+    npgenerator: Optional[np.random.Generator] = None
+) -> str:
     if dropout != 0:
         keywords = [keyword for keyword in keywords if np.random.random() > dropout]
     else:
         keywords = keywords.copy()
 
     keywords += undroppable_keywords
 
     if shuffle:
-        np.random.shuffle(keywords)
+        npgenerator = npgenerator or np.random.default_rng()
+        npgenerator.shuffle(keywords)
 
     return ", ".join(keywords)
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 840f8ae..4444cf9 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -74,8 +74,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        self.token_embedding.eval()
-        self.token_embedding.merged = False
+        self.token_embedding.persist()
 
     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
diff --git a/models/lora.py b/models/lora.py
index 89c4b2e..b7fa58f 100644
--- a/models/lora.py
+++ b/models/lora.py
@@ -46,8 +46,8 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         self.trainable_ids -= 1
 
         if r > 0:
-            self.lora_A = nn.Parameter(self.weight.new_zeros((r, 0)))
-            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+            self.lora_A = nn.ParameterList()
+            self.lora_B = nn.Linear(r, embedding_dim, bias=False)
             self.scaling = self.lora_alpha / self.r
             self.weight.requires_grad = False
 
@@ -83,49 +83,64 @@ class LoraEmbedding(nn.Embedding, LoraLayer):
         if new_ids.shape[0] == 0:
             return
 
-        n1 = self.lora_A.shape[1]
+        n1 = len(self.lora_A)
         n2 = n1 + new_ids.shape[0]
         self.trainable_ids[new_ids] = torch.arange(n1, n2)
 
+        for _ in new_ids:
+            self.lora_A.append(self.weight.new_zeros(self.r))
+
+    def persist(self):
+        if self.r > 0:
+            weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+            if weights is not None:
+                self.weight[mask].data += weights
+            self.trainable_ids[:] = -1
+            self.lora_A = nn.ParameterList()
+
+    def get_weights(self, input_ids: torch.Tensor):
+        trainable_ids = self.trainable_ids[input_ids]
+        mask = ~(trainable_ids == -1)
+        trainable_ids = trainable_ids[mask]
+
+        elems = [self.lora_A[id] for id in trainable_ids]
+
+        if len(elems) == 0:
+            return None, mask
+
+        weights = self.lora_B(self.lora_dropout(torch.stack(elems))) * self.scaling
-        lora_A = nn.Parameter(self.weight.new_zeros((self.r, n2)))
-        self.lora_A = lora_A
+        return weights, mask
 
     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
         if hasattr(self, 'lora_A'):
-            nn.init.zeros_(self.lora_A)
-            nn.init.normal_(self.lora_B)
+            self.lora_A = nn.ParameterList()
+            nn.init.zeros_(self.lora_B.weight)
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
         if self.merge_weights and self.merged:
             if self.r > 0:
-                mask = ~(self.trainable_ids == -1)
-                trainable_ids = self.trainable_ids[mask]
-                self.weight[trainable_ids].data -= (self.lora_B @ self.lora_A).T * self.scaling
+                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                if weights is not None:
+                    self.weight[mask].data -= weights
             self.merged = False
 
     def eval(self):
         nn.Embedding.eval(self)
         if self.merge_weights and not self.merged:
             if self.r > 0:
-                mask = ~(self.trainable_ids == -1)
-                trainable_ids = self.trainable_ids[mask]
-                self.weight[trainable_ids].data += (self.lora_B @ self.lora_A) * self.scaling
+                weights, mask = self.get_weights(torch.arange(self.trainable_ids.shape[0]))
+                if weights is not None:
+                    self.weight[mask].data += weights
             self.merged = True
 
     def forward(self, input_ids: torch.Tensor):
         result = nn.Embedding.forward(self, input_ids)
 
         if self.r > 0 and not self.merged:
-            trainable_ids = self.trainable_ids[input_ids]
-            mask = ~(trainable_ids == -1)
-            trainable_ids = trainable_ids[mask]
-
-            after_A = F.embedding(
-                trainable_ids, self.lora_A.T, self.padding_idx, self.max_norm,
-                self.norm_type, self.scale_grad_by_freq, self.sparse
-            )
-            result[mask] += (after_A @ self.lora_B.T) * self.scaling
+            weights, mask = self.get_weights(input_ids)
+            if weights is not None:
+                result[mask] += weights
 
         return result
diff --git a/train_lora.py b/train_lora.py
index 91bda5c..d5dde02 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,9 +13,11 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from peft import LoraConfig, LoraModel
-from slugify import slugify
 import transformers
 
+import numpy as np
+from slugify import slugify
+
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
@@ -873,6 +875,7 @@ def main():
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+    data_npgenerator = np.random.default_rng(args.seed)
 
     create_datamodule = partial(
         VlpnDataModule,
@@ -893,6 +896,7 @@ def main():
         valid_set_pad=args.valid_set_pad,
         dtype=weight_dtype,
         generator=data_generator,
+        npgenerator=data_npgenerator,
     )
 
     create_lr_scheduler = partial(
diff --git a/train_ti.py b/train_ti.py
index 6c57f4b..7f5fb49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -12,10 +12,12 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from slugify import slugify
 from timm.models import create_model
 import transformers
 
+import numpy as np
+from slugify import slugify
+
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
@@ -852,6 +854,7 @@ def main():
     )
 
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+    data_npgenerator = np.random.default_rng(args.seed)
 
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -894,6 +897,7 @@ def main():
             filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype,
            generator=data_generator,
+            npgenerator=data_npgenerator,
         )
         datamodule.setup()
 
-- 
cgit v1.2.3-70-g09d2
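A minimal usage sketch of the seeded keyword shuffling introduced above (illustrative only, not part of the commit; the keywords and seed below are placeholders, and note that only the shuffle, not the dropout, is driven by the generator passed in):

    import numpy as np

    from data.keywords import keywords_to_str

    # A fixed seed makes the shuffled keyword order reproducible across runs.
    rng = np.random.default_rng(42)  # placeholder seed

    # Undroppable keywords are appended before the shuffle, so they are mixed
    # into the result but never removed by dropout.
    prompt = keywords_to_str(
        ["portrait", "sharp focus", "studio lighting"],  # placeholder keywords
        ["<instance-token>"],                            # placeholder undroppable keyword
        dropout=0,
        shuffle=True,
        npgenerator=rng,
    )
    print(prompt)  # comma-separated keywords in an order fixed by the seed

The same kind of generator, seeded once from args.seed, is what train_lora.py and train_ti.py now pass to VlpnDataModule, so item order and prompt shuffling repeat between runs with the same seed.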