summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py1
-rw-r--r--training/optimization.py2
-rw-r--r--training/ti.py3
3 files changed, 3 insertions, 3 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8cb6414..e239833 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -556,7 +556,6 @@ def main():
556 text_encoder.resize_token_embeddings(len(tokenizer)) 556 text_encoder.resize_token_embeddings(len(tokenizer))
557 557
558 token_embeds = text_encoder.get_input_embeddings().weight.data 558 token_embeds = text_encoder.get_input_embeddings().weight.data
559 original_token_embeds = token_embeds.clone().to(accelerator.device)
560 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) 559 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
561 560
562 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): 561 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
diff --git a/training/optimization.py b/training/optimization.py
index c501ed9..3809f3b 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ from diffusers.utils import logging
6logger = logging.get_logger(__name__) 6logger = logging.get_logger(__name__)
7 7
8 8
9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.001, mid_point=0.4, last_epoch=-1): 9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.01, mid_point=0.4, last_epoch=-1):
10 """ 10 """
11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
diff --git a/training/ti.py b/training/ti.py
index 1318e22..031fe48 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -22,7 +22,8 @@ class TrainableEmbeddings(CLIPTextEmbeddings):
22 self.train_indices = torch.tensor(new_ids) 22 self.train_indices = torch.tensor(new_ids)
23 23
24 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim) 24 self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
25 self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone() 25 self.trainable_embedding.weight.data.zero_()
26 self.trainable_embedding.weight.data[self.train_indices] = self.token_embedding.weight.data[self.train_indices]
26 27
27 def forward( 28 def forward(
28 self, 29 self,