summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-10 14:48:25 +0200
committerVolpeon <git@volpeon.ink>2023-04-10 14:48:25 +0200
commitc6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669 (patch)
tree5b391677d29148edddda073823bda8425228be65 /train_ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.tar.gz
textual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.tar.bz2
textual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.zip
Randomize dataset across cycles
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py
index eb08bda..009495b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -817,6 +817,8 @@ def main():
817 sample_image_size=args.sample_image_size, 817 sample_image_size=args.sample_image_size,
818 ) 818 )
819 819
820 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
821
820 def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): 822 def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
821 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 823 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
822 tokenizer=tokenizer, 824 tokenizer=tokenizer,
@@ -855,9 +857,9 @@ def main():
855 valid_set_size=args.valid_set_size, 857 valid_set_size=args.valid_set_size,
856 train_set_pad=args.train_set_pad, 858 train_set_pad=args.train_set_pad,
857 valid_set_pad=args.valid_set_pad, 859 valid_set_pad=args.valid_set_pad,
858 seed=args.seed,
859 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), 860 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
860 dtype=weight_dtype 861 dtype=weight_dtype,
862 generator=data_generator,
861 ) 863 )
862 datamodule.setup() 864 datamodule.setup()
863 865