From b7b9f7a7fc3a2e6a027175e5a84541ca2291fbb5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 1 Jan 2023 11:36:00 +0100
Subject: Fixed accuracy calc, other improvements

---
 data/csv.py              |  2 +-
 models/clip/tokenizer.py | 18 +++++++++++-------
 train_dreambooth.py      | 30 +++++++++++++++++++++++++++++-
 train_ti.py              | 36 +++++++++++++++++++++++++++++++++---
 training/ti.py           | 48 ------------------------------------------------
 5 files changed, 74 insertions(+), 60 deletions(-)
 delete mode 100644 training/ti.py

diff --git a/data/csv.py b/data/csv.py
index 803271b..af36d9e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -151,7 +151,7 @@ class CSVDataModule():
 
         num_images = len(items)
 
-        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.1)
+        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.2)
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index a3e6e70..37d69a9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -15,6 +15,10 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.token_map: dict[int, list[int]] = {}
+        self.vector_shuffle = False
+
+    def set_use_vector_shuffle(self, enable: bool):
+        self.vector_shuffle = enable
 
     def add_multi_tokens(self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1) -> MultiCLIPTokenizerItem:
         if isinstance(new_tokens, list):
@@ -42,11 +46,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
 
         return MultiCLIPTokenizerItem(new_tokens, ids)
 
-    def expand_id(self, id: int, vector_shuffle=True):
+    def expand_id(self, id: int):
         if id in self.token_map:
             tokens = self.token_map[id]
 
-            if vector_shuffle and len(tokens) > 2:
+            if self.vector_shuffle and len(tokens) > 2:
                 subtokens = tokens[1:-1]
                 np.random.shuffle(subtokens)
                 tokens = tokens[:1] + subtokens + tokens[-1:]
@@ -55,21 +59,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         else:
             return [id]
 
-    def expand_ids(self, ids: list[int], vector_shuffle=True):
+    def expand_ids(self, ids: list[int]):
         return [
             new_id
             for id in ids
-            for new_id in self.expand_id(id, vector_shuffle)
+            for new_id in self.expand_id(id)
         ]
 
-    def _call_one(self, text, *args, vector_shuffle=True, **kwargs):
+    def _call_one(self, text, *args, **kwargs):
         result = super()._call_one(text, *args, **kwargs)
 
         is_batched = isinstance(result.input_ids, (list, tuple)) and isinstance(result.input_ids[0], list)
 
         if is_batched:
-            result.input_ids = [self.expand_ids(batch, vector_shuffle) for batch in result.input_ids]
+            result.input_ids = [self.expand_ids(batch) for batch in result.input_ids]
         else:
-            result.input_ids = self.expand_ids(result.input_ids, vector_shuffle)
+            result.input_ids = self.expand_ids(result.input_ids)
 
         return result
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8fd78f1..1ebcfe3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -231,6 +231,30 @@ def parse_args():
         default=None,
         help="Number of restart cycles in the lr scheduler (if supported)."
     )
+    parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=3,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
     parser.add_argument(
         "--use_ema",
         action="store_true",
@@ -760,6 +784,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
@@ -913,7 +941,7 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
diff --git a/train_ti.py b/train_ti.py
index 19348e5..20a3190 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -224,6 +224,30 @@ def parse_args():
         default=None,
         help="Number of restart cycles in the lr scheduler."
     )
+    parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=2,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
@@ -510,6 +534,8 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
+    tokenizer.set_use_vector_shuffle(True)
+
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
@@ -559,7 +585,7 @@ def main():
     )
 
     if args.find_lr:
-        args.learning_rate = 1e2
+        args.learning_rate = 1e3
 
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -706,6 +732,10 @@ def main():
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            warmup=args.lr_warmup_func,
+            annealing=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
@@ -796,13 +826,13 @@ def main():
         else:
             loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-        acc = (model_pred == latents).float().mean()
+        acc = (model_pred == target).float().mean()
 
         return loss, acc, bsz
 
     if args.find_lr:
         lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop)
-        lr_finder.run(min_lr=1e-6, num_train_batches=1)
+        lr_finder.run(min_lr=1e-4)
 
         plt.savefig(basepath.joinpath("lr.png"))
         plt.close()
diff --git a/training/ti.py b/training/ti.py
deleted file mode 100644
index 031fe48..0000000
--- a/training/ti.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-from transformers.models.clip import CLIPTextModel, CLIPTextConfig
-from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
-
-
-def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
-    text_encoder.text_model.embeddings = text_embeddings
-
-
-class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
-        super().__init__(config)
-
-        self.token_embedding = embeddings.token_embedding
-        self.position_embedding = embeddings.position_embedding
-
-        self.train_indices = torch.tensor(new_ids)
-
-        self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
-        self.trainable_embedding.weight.data.zero_()
-        self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices]
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.Tensor:
-        device = input_ids.device
-        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
-
-        if position_ids is None:
-            position_ids = self.position_ids[:, :seq_length]
-
-        if inputs_embeds is None:
-            mask = torch.isin(input_ids, self.train_indices.to(device))
-            inputs_embeds = self.token_embedding(input_ids)
-            inputs_embeds[mask] = self.trainable_embedding(input_ids)[mask]
-
-        position_embeddings = self.position_embedding(position_ids)
-        embeddings = inputs_embeds + position_embeddings
-
-        return embeddings
-- 
cgit v1.2.3-70-g09d2
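The tokenizer change turns vector_shuffle from a per-call argument into instance state toggled with set_use_vector_shuffle(); train_ti.py now enables it once after loading the models. With it enabled, expand_id() keeps the first and last vector ids of a multi-vector token in place and shuffles only the interior ids. A minimal, self-contained sketch of that expansion logic, written as a plain function instead of the CLIPTokenizer subclass, with a made-up token_map purely for illustration:

    import numpy as np

    def expand_id(id: int, token_map: dict[int, list[int]], vector_shuffle: bool) -> list[int]:
        # Ids without a multi-vector mapping pass through unchanged.
        if id not in token_map:
            return [id]

        tokens = token_map[id]

        # Shuffle only the interior ids so the first and last vectors stay anchored.
        if vector_shuffle and len(tokens) > 2:
            subtokens = tokens[1:-1]
            np.random.shuffle(subtokens)
            tokens = tokens[:1] + subtokens + tokens[-1:]

        return tokens

    # Placeholder id 42 expands to four embedding ids; only 1002 and 1003 may swap places.
    print(expand_id(42, {42: [1001, 1002, 1003, 1004]}, vector_shuffle=True))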
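The "accuracy calc" fix in both training scripts makes the metric compare the prediction with the same tensor the MSE loss uses: target is the sampled noise (epsilon parameterization) or the velocity (v-prediction), never the clean latents, so comparing against latents measured the prediction against a tensor it was never trained to match. A condensed sketch of the surrounding computation, assuming the usual diffusers-style target selection (the scheduler handling is not part of this patch):

    import torch
    import torch.nn.functional as F

    def loss_and_acc(model_pred, latents, noise, timesteps, noise_scheduler):
        # The training target depends on the scheduler's parameterization.
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # The fixed line: accuracy is measured against target, not against latents.
        acc = (model_pred == target).float().mean()

        return loss, acc

Exact float equality is a strict criterion; a tolerance-based variant (e.g. torch.isclose) would also count near-misses, but the patch keeps the strict comparison.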
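The new --lr_warmup_func, --lr_warmup_exp, --lr_annealing_func and --lr_annealing_exp flags are forwarded to get_one_cycle_schedule(), whose implementation lives elsewhere in the repository and is not shown here. Purely as a hypothetical sketch of what a one-cycle LambdaLR with those knobs could look like (the function name, the warmup fraction, and the half_cos handling below are assumptions, not the repository's actual code):

    import math
    from torch.optim.lr_scheduler import LambdaLR

    def one_cycle_schedule(optimizer, num_training_steps, warmup="cos", annealing="cos",
                           warmup_exp=1, annealing_exp=3, warmup_fraction=0.3):
        # Hypothetical split point between the warmup ramp and the annealing phase.
        mid = max(int(num_training_steps * warmup_fraction), 1)

        def lr_lambda(step):
            if step < mid:
                p = step / mid
                # "linear" ramps straight up; "cos" follows a half-cosine raised to warmup_exp.
                return p if warmup == "linear" else ((1 - math.cos(p * math.pi)) / 2) ** warmup_exp
            p = (step - mid) / max(num_training_steps - mid, 1)
            if annealing == "linear":
                return 1.0 - p
            # "cos" / "half_cos": cosine decay from 1 to 0, sharpened by annealing_exp
            # (the real scheduler presumably distinguishes the two; this sketch does not).
            return ((1 + math.cos(p * math.pi)) / 2) ** annealing_exp

        return LambdaLR(optimizer, lr_lambda)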