diff options
author | Volpeon <git@volpeon.ink> | 2023-04-10 14:48:25 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-10 14:48:25 +0200 |
commit | c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669 (patch) | |
tree | 5b391677d29148edddda073823bda8425228be65 /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.tar.gz textual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.tar.bz2 textual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.zip |
Randomize dataset across cycles
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index eb08bda..009495b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -817,6 +817,8 @@ def main(): | |||
817 | sample_image_size=args.sample_image_size, | 817 | sample_image_size=args.sample_image_size, |
818 | ) | 818 | ) |
819 | 819 | ||
820 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | ||
821 | |||
820 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): | 822 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): |
821 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 823 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
822 | tokenizer=tokenizer, | 824 | tokenizer=tokenizer, |
@@ -855,9 +857,9 @@ def main(): | |||
855 | valid_set_size=args.valid_set_size, | 857 | valid_set_size=args.valid_set_size, |
856 | train_set_pad=args.train_set_pad, | 858 | train_set_pad=args.train_set_pad, |
857 | valid_set_pad=args.valid_set_pad, | 859 | valid_set_pad=args.valid_set_pad, |
858 | seed=args.seed, | ||
859 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 860 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), |
860 | dtype=weight_dtype | 861 | dtype=weight_dtype, |
862 | generator=data_generator, | ||
861 | ) | 863 | ) |
862 | datamodule.setup() | 864 | datamodule.setup() |
863 | 865 | ||