diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 10 |
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/train_lora.py b/train_lora.py index ffca304..9a42cae 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
26 | from training.lora import LoraAttnProcessor | 26 | from training.lora import LoraAttnProcessor |
27 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
28 | from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
29 | from models.clip.prompt import PromptProcessor | 29 | from models.clip.prompt import PromptProcessor |
30 | 30 | ||
31 | logger = get_logger(__name__) | 31 | logger = get_logger(__name__) |
@@ -513,11 +513,9 @@ def main(): | |||
513 | 513 | ||
514 | print(f"Training added text embeddings") | 514 | print(f"Training added text embeddings") |
515 | 515 | ||
516 | freeze_params(itertools.chain( | 516 | text_encoder.text_model.encoder.requires_grad_(False) |
517 | text_encoder.text_model.encoder.parameters(), | 517 | text_encoder.text_model.final_layer_norm.requires_grad_(False) |
518 | text_encoder.text_model.final_layer_norm.parameters(), | 518 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) |
519 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
520 | )) | ||
521 | 519 | ||
522 | index_fixed_tokens = torch.arange(len(tokenizer)) | 520 | index_fixed_tokens = torch.arange(len(tokenizer)) |
523 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | 521 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] |