summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-16 09:44:12 +0200
committerVolpeon <git@volpeon.ink>2023-04-16 09:44:12 +0200
commit1a0161f345191d78a19eec829f9d73b2c2c72f94 (patch)
tree6d7bcc67672ebf26454b3254b4bd9d5ec7e64a16 /train_lora.py
parentFix (diff)
downloadtextual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.tar.gz
textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.tar.bz2
textual-inversion-diff-1a0161f345191d78a19eec829f9d73b2c2c72f94.zip
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/train_lora.py b/train_lora.py
index 91bda5c..d5dde02 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,9 +13,11 @@ from accelerate import Accelerator
13from accelerate.logging import get_logger 13from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 14from accelerate.utils import LoggerType, set_seed
15from peft import LoraConfig, LoraModel 15from peft import LoraConfig, LoraModel
16from slugify import slugify
17import transformers 16import transformers
18 17
18import numpy as np
19from slugify import slugify
20
19from util.files import load_config, load_embeddings_from_dir 21from util.files import load_config, load_embeddings_from_dir
20from data.csv import VlpnDataModule, keyword_filter 22from data.csv import VlpnDataModule, keyword_filter
21from training.functional import train, add_placeholder_tokens, get_models 23from training.functional import train, add_placeholder_tokens, get_models
@@ -873,6 +875,7 @@ def main():
873 ) 875 )
874 876
875 data_generator = torch.Generator(device="cpu").manual_seed(args.seed) 877 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
878 data_npgenerator = np.random.default_rng(args.seed)
876 879
877 create_datamodule = partial( 880 create_datamodule = partial(
878 VlpnDataModule, 881 VlpnDataModule,
@@ -893,6 +896,7 @@ def main():
893 valid_set_pad=args.valid_set_pad, 896 valid_set_pad=args.valid_set_pad,
894 dtype=weight_dtype, 897 dtype=weight_dtype,
895 generator=data_generator, 898 generator=data_generator,
899 npgenerator=data_npgenerator,
896 ) 900 )
897 901
898 create_lr_scheduler = partial( 902 create_lr_scheduler = partial(