From 1af6c15f795b5ba4df9179d8c59c6b595040a33f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 25 Dec 2022 14:59:00 +0100 Subject: Update --- train_dreambooth.py | 1 - training/optimization.py | 2 +- training/ti.py | 3 ++- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index 8cb6414..e239833 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -556,7 +556,6 @@ def main(): text_encoder.resize_token_embeddings(len(tokenizer)) token_embeds = text_encoder.get_input_embeddings().weight.data - original_token_embeds = token_embeds.clone().to(accelerator.device) initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): diff --git a/training/optimization.py b/training/optimization.py index c501ed9..3809f3b 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -6,7 +6,7 @@ from diffusers.utils import logging logger = logging.get_logger(__name__) -def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.001, mid_point=0.4, last_epoch=-1): +def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.01, mid_point=0.4, last_epoch=-1): """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. diff --git a/training/ti.py b/training/ti.py index 1318e22..031fe48 100644 --- a/training/ti.py +++ b/training/ti.py @@ -22,7 +22,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings): self.train_indices = torch.tensor(new_ids) self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) - self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() + self.trainable_embedding.weight.data.zero_() + self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices] def forward( self, -- cgit v1.2.3-70-g09d2