From c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 10 Apr 2023 14:48:25 +0200
Subject: Randomize dataset across cycles

---
 data/csv.py   | 14 +++++---------
 train_lora.py |  4 +++-
 train_ti.py   |  6 ++++--
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 818fcd9..3af9925 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -192,7 +192,7 @@ class VlpnDataModule():
         valid_set_size: Optional[int] = None,
         train_set_pad: Optional[int] = None,
         valid_set_pad: Optional[int] = None,
-        seed: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
     ):
@@ -224,10 +224,10 @@ class VlpnDataModule():
         self.valid_set_size = valid_set_size
         self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
         self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
-        self.seed = seed
         self.filter = filter
         self.batch_size = batch_size
         self.dtype = dtype
+        self.generator = generator
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         tpl_image = template["image"] if "image" in template else "{}"
@@ -304,16 +304,12 @@ class VlpnDataModule():
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        generator = torch.Generator(device="cpu")
-        if self.seed is not None:
-            generator = generator.manual_seed(self.seed)
-
         collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
         else:
-            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
+            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator)
 
         data_train = self.pad_items(data_train, self.num_class_images)
 
@@ -324,7 +320,7 @@ class VlpnDataModule():
             data_train, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            batch_size=self.batch_size, fill_batch=True, generator=generator,
+            batch_size=self.batch_size, fill_batch=True, generator=self.generator,
             size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
         )
@@ -344,7 +340,7 @@ class VlpnDataModule():
             data_val, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=True,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            batch_size=self.batch_size, generator=generator,
+            batch_size=self.batch_size, generator=self.generator,
             size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
         )
 
diff --git a/train_lora.py b/train_lora.py
index 4bbc64e..0d8ee23 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -865,6 +865,8 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )
 
+    data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+
     create_datamodule = partial(
         VlpnDataModule,
         data_file=args.train_data_file,
@@ -882,8 +884,8 @@ def main():
         valid_set_size=args.valid_set_size,
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
-        seed=args.seed,
         dtype=weight_dtype,
+        generator=data_generator,
     )
 
     create_lr_scheduler = partial(
diff --git a/train_ti.py b/train_ti.py
index eb08bda..009495b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -817,6 +817,8 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -855,9 +857,9 @@ def main():
         valid_set_size=args.valid_set_size,
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
-        seed=args.seed,
         filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
-        dtype=weight_dtype
+        dtype=weight_dtype,
+        generator=data_generator,
     )
 
     datamodule.setup()
-- 
cgit v1.2.3-70-g09d2
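
Note on the behavioural change (not part of the patch): previously VlpnDataModule.setup() built a fresh torch.Generator from self.seed on every call, so every training cycle reproduced the identical train/valid split and bucket shuffle; with this commit a single generator is seeded once in main() and shared, so its state advances between cycles and each cycle sees a different, yet still seed-reproducible, randomization. A minimal standalone sketch of that distinction, using only stock PyTorch (the toy dataset and split sizes below are made up for illustration):

# Illustrative only -- not the project's code. Contrasts re-seeding per cycle
# (old behaviour) with one shared, pre-seeded generator (new behaviour).
import torch
from torch.utils.data import random_split

items = list(range(10))   # stand-in for the prepared dataset items
seed = 42

# Old: a fresh generator seeded inside every setup() call, so the
# train/valid split is identical on every cycle.
for cycle in range(2):
    g = torch.Generator(device="cpu").manual_seed(seed)
    _, val = random_split(items, [8, 2], generator=g)
    print("per-cycle seed:  ", sorted(val))   # same two items each cycle

# New: one generator seeded once (as in main()) and reused, so each split is
# drawn from the advancing generator state and cycles generally differ,
# while the whole sequence stays reproducible for a given seed.
g = torch.Generator(device="cpu").manual_seed(seed)
for cycle in range(2):
    _, val = random_split(items, [8, 2], generator=g)
    print("shared generator:", sorted(val))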