diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 1 |
1 files changed, 0 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 8cb6414..e239833 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -556,7 +556,6 @@ def main(): | |||
556 | text_encoder.resize_token_embeddings(len(tokenizer)) | 556 | text_encoder.resize_token_embeddings(len(tokenizer)) |
557 | 557 | ||
558 | token_embeds = text_encoder.get_input_embeddings().weight.data | 558 | token_embeds = text_encoder.get_input_embeddings().weight.data |
559 | original_token_embeds = token_embeds.clone().to(accelerator.device) | ||
560 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 559 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) |
561 | 560 | ||
562 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 561 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): |