| author | Volpeon <git@volpeon.ink> | 2022-12-28 18:08:36 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-28 18:08:36 +0100 |
| commit | 83725794618164210a12843381724252fdd82cc2 (patch) | |
| tree | ec29ade9891fe08dd10b5033214fc09237c2cb86 /train_lora.py | |
| parent | Improved learning rate finder (diff) | |
Integrated updates from diffusers
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 10 |
1 file changed, 4 insertions(+), 6 deletions(-)
```diff
diff --git a/train_lora.py b/train_lora.py
index ffca304..9a42cae 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
-from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
+from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -513,11 +513,9 @@ def main():
 
     print(f"Training added text embeddings")
 
-    freeze_params(itertools.chain(
-        text_encoder.text_model.encoder.parameters(),
-        text_encoder.text_model.final_layer_norm.parameters(),
-        text_encoder.text_model.embeddings.position_embedding.parameters(),
-    ))
+    text_encoder.text_model.encoder.requires_grad_(False)
+    text_encoder.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
 
     index_fixed_tokens = torch.arange(len(tokenizer))
     index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
```
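For context on the second hunk: PyTorch's `Module.requires_grad_(False)` flips the `requires_grad` flag on every parameter the module owns, so the three new calls freeze the same parameter sets the removed `freeze_params(itertools.chain(...))` call did, without the helper import. A minimal sketch of the equivalence; the toy module and the `freeze_params` body here are assumptions standing in for the real CLIP text encoder and the removed `training.util` helper:

```python
import itertools

from torch import nn


def freeze_params(params):
    # Assumed body of the removed training.util helper: disable
    # gradients for an iterable of parameters, one by one.
    for param in params:
        param.requires_grad = False


# Toy stand-in for text_encoder.text_model (the real object is a
# transformers CLIPTextTransformer with the same attribute names).
text_model = nn.Module()
text_model.encoder = nn.Linear(4, 4)
text_model.final_layer_norm = nn.LayerNorm(4)

# Old style: chain the parameter iterators and freeze them via the helper.
freeze_params(itertools.chain(
    text_model.encoder.parameters(),
    text_model.final_layer_norm.parameters(),
))

# New style: requires_grad_(False) sets the flag in place on every
# parameter of the submodule, so no helper is needed.
text_model.encoder.requires_grad_(False)
text_model.final_layer_norm.requires_grad_(False)

assert all(not p.requires_grad for p in text_model.parameters())
```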

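The unchanged `index_fixed_tokens` lines at the end of the hunk show what this freezing is paired with: `torch.isin` marks the placeholder token ids within the full vocabulary range, and the complement (`~`) keeps every other token id, i.e. the embeddings that stay fixed while only the placeholder embeddings train. A small demonstration with hypothetical ids (in `train_lora.py` these come from the tokenizer and the newly added placeholder token):

```python
import torch

vocab_size = 10                # hypothetical stand-in for len(tokenizer)
placeholder_token_id = [7, 9]  # hypothetical stand-in for the added token ids

index_fixed_tokens = torch.arange(vocab_size)
# Keep only the ids that are NOT placeholder tokens.
index_fixed_tokens = index_fixed_tokens[
    ~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))
]
print(index_fixed_tokens)  # tensor([0, 1, 2, 3, 4, 5, 6, 8])
```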