diff options
| author | Volpeon <git@volpeon.ink> | 2023-04-10 14:48:25 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-10 14:48:25 +0200 |
| commit | c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669 (patch) | |
| tree | 5b391677d29148edddda073823bda8425228be65 /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.tar.gz textual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.tar.bz2 textual-inversion-diff-c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669.zip | |
Randomize dataset across cycles
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_ti.py b/train_ti.py index eb08bda..009495b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -817,6 +817,8 @@ def main(): | |||
| 817 | sample_image_size=args.sample_image_size, | 817 | sample_image_size=args.sample_image_size, |
| 818 | ) | 818 | ) |
| 819 | 819 | ||
| 820 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | ||
| 821 | |||
| 820 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): | 822 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): |
| 821 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 823 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 822 | tokenizer=tokenizer, | 824 | tokenizer=tokenizer, |
| @@ -855,9 +857,9 @@ def main(): | |||
| 855 | valid_set_size=args.valid_set_size, | 857 | valid_set_size=args.valid_set_size, |
| 856 | train_set_pad=args.train_set_pad, | 858 | train_set_pad=args.train_set_pad, |
| 857 | valid_set_pad=args.valid_set_pad, | 859 | valid_set_pad=args.valid_set_pad, |
| 858 | seed=args.seed, | ||
| 859 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 860 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), |
| 860 | dtype=weight_dtype | 861 | dtype=weight_dtype, |
| 862 | generator=data_generator, | ||
| 861 | ) | 863 | ) |
| 862 | datamodule.setup() | 864 | datamodule.setup() |
| 863 | 865 | ||
