-rw-r--r--   train_dreambooth.py |  9
-rw-r--r--   train_ti.py         | 10
-rw-r--r--   training/ti.py      | 15
3 files changed, 21 insertions, 13 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 51e881a..8cb6414 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -568,9 +568,16 @@ def main():
         print(f"Training entire text encoder.")
     else:
         print(f"Training added text embeddings")
-        text_encoder.requires_grad_(False)
+
         patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
+        freeze_params(itertools.chain(
+            text_encoder.text_model.encoder.parameters(),
+            text_encoder.text_model.final_layer_norm.parameters(),
+            text_encoder.text_model.embeddings.position_embedding.parameters(),
+            text_encoder.text_model.embeddings.token_embedding.parameters(),
+        ))
+
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
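freeze_params is called above but not defined in this diff (train_ti.py imports it from training.util below). A minimal sketch of what such a helper presumably does, assuming it only disables gradient tracking on the parameters it receives:

from typing import Iterable

from torch import nn


def freeze_params(params: Iterable[nn.Parameter]) -> None:
    # Presumed behaviour: turn off autograd for every parameter passed in,
    # e.g. the chained encoder/embedding parameters from the hunk above.
    for param in params:
        param.requires_grad = False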
diff --git a/train_ti.py b/train_ti.py
index a12b889..5f37d54 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from training.ti import patch_trainable_embeddings
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args, freeze_params
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -515,10 +515,16 @@ def main():
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
 
     patch_trainable_embeddings(text_encoder, placeholder_token_id)
 
+    freeze_params(itertools.chain(
+        text_encoder.text_model.encoder.parameters(),
+        text_encoder.text_model.final_layer_norm.parameters(),
+        text_encoder.text_model.embeddings.position_embedding.parameters(),
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
+    ))
+
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
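The reordering in both scripts appears to serve the same purpose: patch_trainable_embeddings now runs before anything is frozen, and only the original encoder, final layer norm, and embedding tables are frozen afterwards, so the new trainable copy keeps gradients enabled (nn.Embedding weights require gradients by default, which is presumably why the explicit requires_grad = True is dropped in training/ti.py below). A hedged sanity check, assuming the patched module layout from training/ti.py:

# Illustrative check, not part of the diff: after patching and freezing,
# only the trainable embedding copy should still require gradients.
trainable = [n for n, p in text_encoder.named_parameters() if p.requires_grad]
# Expected to be roughly ["text_model.embeddings.trainable_embedding.weight"].
print(trainable)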
diff --git a/training/ti.py b/training/ti.py
index 8b2fdd6..dc33e5e 100644
--- a/training/ti.py
+++ b/training/ti.py
@@ -8,26 +8,21 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings
 
 
 def patch_trainable_embeddings(text_encoder: CLIPTextModel, new_ids: list[int]):
-    text_embeddings = TrainableEmbeddings(text_encoder.config, new_ids)
-
-    text_embeddings.token_embedding = text_encoder.text_model.embeddings.token_embedding
-    text_embeddings.token_embedding.weight.requires_grad = False
-
-    text_embeddings.position_embedding = text_encoder.text_model.embeddings.position_embedding
-    text_embeddings.position_embedding.weight.requires_grad = False
-
+    text_embeddings = TrainableEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, new_ids)
     text_encoder.text_model.embeddings = text_embeddings
 
 
 class TrainableEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, new_ids: list[int]):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, new_ids: list[int]):
         super().__init__(config)
 
         self.train_indices = torch.tensor(new_ids)
 
         self.trainable_embedding = nn.Embedding(self.token_embedding.num_embeddings, self.token_embedding.embedding_dim)
+
+        self.token_embedding = embeddings.token_embedding
+        self.position_embedding = embeddings.position_embedding
         self.trainable_embedding.weight.data = self.token_embedding.weight.data.clone()
-        self.trainable_embedding.weight.requires_grad = True
 
     def forward(
         self,
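The forward definition above is only trailing context in this hunk and its body is not part of the change. Purely as an illustration of how such a patched embedding layer could route lookups, assuming train_indices holds the placeholder token ids whose rows should come from trainable_embedding rather than the frozen table (a sketch, not the file's actual forward):

import torch

def forward(self, input_ids=None, position_ids=None, inputs_embeds=None):
    if position_ids is None:
        # Same default position ids CLIPTextEmbeddings would use.
        position_ids = self.position_ids[:, : input_ids.shape[-1]]

    if inputs_embeds is None:
        # Look up both the frozen table and the trainable copy, then take the
        # trainable rows only where the token id is one of the new ids.
        frozen = self.token_embedding(input_ids)
        trainable = self.trainable_embedding(input_ids)
        mask = torch.isin(input_ids, self.train_indices.to(input_ids.device))
        inputs_embeds = torch.where(mask.unsqueeze(-1), trainable, frozen)

    return inputs_embeds + self.position_embedding(position_ids)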
