From c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 10 Apr 2023 14:48:25 +0200
Subject: Randomize dataset across cycles

---
 train_ti.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index eb08bda..009495b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -817,6 +817,8 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -855,9 +857,9 @@ def main():
             valid_set_size=args.valid_set_size,
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
-            seed=args.seed,
             filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
-            dtype=weight_dtype
+            dtype=weight_dtype,
+            generator=data_generator,
         )
         datamodule.setup()
 
-- 
cgit v1.2.3-54-g00ecf