From 9f5f70cb2a8919cb07821f264bf0fd75bfa10584 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 8 Apr 2023 17:38:49 +0200
Subject: Update

---
 models/clip/embeddings.py |  2 +-
 models/sparse.py          |  2 +-
 train_lora.py             | 32 ++++++++++++++++++++++----------
 train_ti.py               | 23 ++++++++++++++++++-----
 training/strategy/lora.py |  2 +-
 training/strategy/ti.py   | 12 ++++++------
 6 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 63a141f..6fda33c 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -96,7 +96,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         save_file({"embed": self.get_embed(input_ids)}, filename)
 
     def persist(self):
-        input_ids = torch.arange(self.token_embedding.num_embeddings)
+        input_ids = torch.arange(self.token_embedding.num_embeddings, device=self.token_override_embedding.device)
         embs, mask = self.token_override_embedding(input_ids)
         if embs is not None:
             input_ids = input_ids[mask]
diff --git a/models/sparse.py b/models/sparse.py
index 8910316..d706db5 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -11,7 +11,7 @@ class PseudoSparseEmbedding(nn.Module):
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
-        self.mapping = torch.zeros(0, device=device, dtype=torch.long)
+        self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
         ids = self.mapping[input_ids.to(self.mapping.device)]
diff --git a/train_lora.py b/train_lora.py
index 1626be6..e4b5546 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -92,6 +92,12 @@ def parse_args():
         nargs='*',
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
     parser.add_argument(
         "--initializer_noise",
         type=float,
@@ -592,6 +598,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -890,7 +902,7 @@ def main():
 
     pti_datamodule = create_datamodule(
         batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
     )
     pti_datamodule.setup()
 
@@ -906,7 +918,7 @@ def main():
     pti_optimizer = create_optimizer(
         [
             {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
                 "lr": args.learning_rate_pti,
                 "weight_decay": 0,
             },
@@ -937,7 +949,7 @@ def main():
         sample_frequency=pti_sample_frequency,
     )
 
-    # embeddings.persist()
+    embeddings.persist()
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -962,13 +974,13 @@ def main():
     params_to_optimize = []
     group_labels = []
 
-    if len(args.placeholder_tokens) != 0:
-        params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-            "lr": args.learning_rate_text,
-            "weight_decay": 0,
-        })
-        group_labels.append("emb")
+    # if len(args.placeholder_tokens) != 0:
+    #     params_to_optimize.append({
+    #         "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+    #         "lr": args.learning_rate_text,
+    #         "weight_decay": 0,
+    #     })
+    #     group_labels.append("emb")
     params_to_optimize += [
         {
             "params": (
diff --git a/train_ti.py b/train_ti.py
index 48858cc..daf8bc5 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+from typing import Union
 
 import math
 import torch
@@ -74,6 +75,12 @@ def parse_args():
         nargs='*',
         help="A token to use as initializer word."
     )
+    parser.add_argument(
+        "--filter_tokens",
+        type=str,
+        nargs='*',
+        help="Tokens to filter the dataset by."
+    )
     parser.add_argument(
         "--initializer_noise",
         type=float,
@@ -538,6 +545,12 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
+    if args.filter_tokens is None:
+        args.filter_tokens = args.placeholder_tokens.copy()
+
+    if isinstance(args.filter_tokens, str):
+        args.filter_tokens = [args.filter_tokens]
+
     if args.sequential:
         args.alias_tokens += [
             item
@@ -779,13 +792,11 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+    def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         if len(placeholder_tokens) == 1:
             sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
             sample_output_dir = output_dir / "samples"
-            metrics_output_file = output_dir / "lr.png"
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -800,6 +811,8 @@ def main():
 
         print(f"{i + 1}: {stats}")
 
+        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
+
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
             batch_size=args.train_batch_size,
@@ -820,7 +833,7 @@ def main():
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
             seed=args.seed,
-            filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections),
+            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype
         )
         datamodule.setup()
@@ -834,7 +847,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
         optimizer = create_optimizer(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
             lr=args.learning_rate,
         )
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cfdc504..ae85401 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -93,7 +93,7 @@ def lora_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 720ebf3..289d6bd 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -72,7 +72,7 @@ def textual_inversion_strategy_callbacks(
 
     if use_ema:
         ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+            text_encoder.text_model.embeddings.token_override_embedding.parameters(),
             inv_gamma=ema_inv_gamma,
             power=ema_power,
             max_value=ema_max_decay,
@@ -84,20 +84,20 @@ def textual_inversion_strategy_callbacks(
     def ema_context():
         if ema_embeddings is not None:
             return ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.token_override_embedding.params.parameters()
+                text_encoder.text_model.embeddings.token_override_embedding.parameters()
             )
         else:
             return nullcontext()
 
     @contextmanager
     def on_train(epoch: int):
-        text_encoder.text_model.embeddings.token_override_embedding.params.train()
+        text_encoder.train()
         tokenizer.train()
         yield
 
     @contextmanager
     def on_eval():
-        text_encoder.text_model.embeddings.token_override_embedding.params.eval()
+        text_encoder.eval()
         tokenizer.eval()
 
         with ema_context():
@@ -108,7 +108,7 @@ def textual_inversion_strategy_callbacks(
         if use_emb_decay:
             params = [
                 p
-                for p in text_encoder.text_model.embeddings.token_override_embedding.params
+                for p in text_encoder.text_model.embeddings.token_override_embedding.parameters()
                 if p.grad is not None
             ]
             return torch.stack(params) if len(params) != 0 else None
@@ -116,7 +116,7 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.params.parameters())
+            ema_embeddings.step(text_encoder.text_model.embeddings.token_override_embedding.parameters())
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] or lrs["0"]
--
cgit v1.2.3-70-g09d2