From 12b9aca96a36dd77a6b2b99bbc1743d87a7ce733 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 24 Jun 2023 21:00:29 +0200
Subject: Update

---
 models/clip/embeddings.py       |  4 ++--
 models/sparse.py                | 11 ++++++++---
 train_dreambooth.py             |  4 ++--
 train_lora.py                   |  6 +++---
 train_ti.py                     |  6 +++---
 training/strategy/dreambooth.py |  3 ++-
 6 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 8c3c6d4..afb7430 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -79,8 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
     def save_embed(self, input_ids: list[int], filename: Path):
         save_file({"embed": self.get_embed(input_ids)}, filename)

-    def persist(self):
-        self.token_embedding.persist()
+    def persist(self, clear=False):
+        self.token_embedding.persist(clear)

     def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
         if isinstance(input_ids, list):
diff --git a/models/sparse.py b/models/sparse.py
index e5897c9..55c9837 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -89,10 +89,15 @@ class SparseEmbedding(nn.Embedding):

         return weights

-    def persist(self):
+    def persist(self, clear=False):
         self.weight.data += self.get_weights(torch.arange(self.trainable_ids.shape[0]))
-        self.trainable_ids[:] = -1
-        self.trainable = nn.ParameterList()
+
+        if clear:
+            self.trainable_ids[:] = -1
+            self.trainable = nn.ParameterList()
+        else:
+            for param in self.trainable:
+                param.zero_()

     def reset_parameters(self):
         nn.Embedding.reset_parameters(self)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index beb65fc..929310b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -661,7 +661,7 @@ def main():
             placeholder_tokens=alias_placeholder_tokens,
             initializer_tokens=alias_initializer_tokens,
         )
-        embeddings.persist()
+        embeddings.persist(True)
         print(
             f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
         )
@@ -682,7 +682,7 @@ def main():
             f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
         )

-        embeddings.persist()
+        embeddings.persist(True)

     if len(args.placeholder_tokens) != 0:
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
diff --git a/train_lora.py b/train_lora.py
index 2a43252..eeac81f 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -777,7 +777,7 @@ def main():
             placeholder_tokens=alias_placeholder_tokens,
             initializer_tokens=alias_initializer_tokens,
         )
-        embeddings.persist()
+        embeddings.persist(True)
         print(
             f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
         )
@@ -806,7 +806,7 @@ def main():
         if args.train_dir_embeddings:
             print("Training embeddings from embeddings dir")
         else:
-            embeddings.persist()
+            embeddings.persist(True)

     if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
         embeddings = ensure_embeddings()
@@ -1117,7 +1117,7 @@ def main():
                 no_val=True,
             )

-            embeddings.persist()
+            embeddings.persist(True)

     # LORA
     # --------------------------------------------------------------------------------
diff --git a/train_ti.py b/train_ti.py
index 89f4113..1d0cb6f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -691,7 +691,7 @@ def main():
             placeholder_tokens=alias_placeholder_tokens,
             initializer_tokens=alias_initializer_tokens,
         )
-        embeddings.persist()
+        embeddings.persist(True)
         print(
             f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
         )
@@ -712,7 +712,7 @@ def main():
         args.placeholder_tokens = added_tokens
         print("Training embeddings from embeddings dir")
     else:
-        embeddings.persist()
+        embeddings.persist(True)

     if args.scale_lr:
         args.learning_rate = (
@@ -1067,7 +1067,7 @@ def main():
             args.train_data_template,
         ):
             run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
-            embeddings.persist()
+            embeddings.persist(True)


 if __name__ == "__main__":
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index bd853e2..3d1abf7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -98,7 +98,6 @@ def dreambooth_strategy_callbacks(

         if cycle < train_text_encoder_cycles:
             text_encoder.train()
-            tokenizer.train()

         yield

@@ -155,6 +154,8 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

+        text_encoder_.text_model.embeddings.persist(False)
+
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
--
cgit v1.2.3-70-g09d2
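
Editor's note: the effect of the new `clear` flag on `SparseEmbedding.persist()` can be illustrated with a small self-contained sketch. This is not code from the repository; `TinySparseEmbedding` and its dense `delta` parameter are hypothetical stand-ins for the real `SparseEmbedding`, which tracks a sparse set of trainable ids and applies the deltas on lookup.

```python
import torch
import torch.nn as nn


class TinySparseEmbedding(nn.Embedding):
    # Hypothetical stand-in: a dense base weight plus one trainable delta
    # row per token. forward() is omitted; the real SparseEmbedding also
    # applies the trainable deltas when embeddings are looked up.
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)
        self.delta = nn.Parameter(torch.zeros(num_embeddings, embedding_dim))

    def persist(self, clear: bool = False):
        # Fold the learned deltas into the base embedding matrix.
        self.weight.data += self.delta.data
        if clear:
            # persist(True): drop the trainable state for good, as the
            # training scripts do after loading aliases/dir embeddings.
            self.delta = nn.Parameter(torch.empty(0, self.embedding_dim))
        else:
            # persist(False): zero the deltas but keep them trainable, as
            # the dreambooth checkpoint callback does mid-training, so the
            # folded weights are not applied twice.
            self.delta.data.zero_()


emb = TinySparseEmbedding(10, 4)
nn.init.zeros_(emb.weight)       # zero the base weights so the math is visible
with torch.no_grad():
    emb.delta[3] += 1.0          # pretend training updated token 3

emb.persist(False)                          # bake the delta in, keep training state
assert emb.weight[3].sum().item() == 4.0    # delta folded into base weights
assert emb.delta.requires_grad              # deltas survive for further training
assert emb.delta.abs().sum().item() == 0.0  # ...but are reset to zero
```

This is why the strategy callback can call `persist(False)` before building the checkpoint pipeline: the current embeddings are saved into the model weights without terminating the ongoing textual-inversion training.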