author     Volpeon <git@volpeon.ink>  2023-04-10 14:48:25 +0200
committer  Volpeon <git@volpeon.ink>  2023-04-10 14:48:25 +0200
commit     c6e97ab40fd7a4e6d53b3e9e4aa28f7dfb6de669 (patch)
tree       5b391677d29148edddda073823bda8425228be65
parent     Update (diff)
Randomize dataset across cycles
-rw-r--r--  data/csv.py    | 14
-rw-r--r--  train_lora.py  |  4
-rw-r--r--  train_ti.py    |  6
3 files changed, 12 insertions, 12 deletions
diff --git a/data/csv.py b/data/csv.py
index 818fcd9..3af9925 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -192,7 +192,7 @@ class VlpnDataModule():
         valid_set_size: Optional[int] = None,
         train_set_pad: Optional[int] = None,
         valid_set_pad: Optional[int] = None,
-        seed: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
     ):
@@ -224,10 +224,10 @@ class VlpnDataModule():
         self.valid_set_size = valid_set_size
         self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
         self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
-        self.seed = seed
         self.filter = filter
         self.batch_size = batch_size
         self.dtype = dtype
+        self.generator = generator
 
     def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         tpl_image = template["image"] if "image" in template else "{}"
@@ -304,16 +304,12 @@ class VlpnDataModule():
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        generator = torch.Generator(device="cpu")
-        if self.seed is not None:
-            generator = generator.manual_seed(self.seed)
-
         collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0)
 
         if valid_set_size == 0:
             data_train, data_val = items, items
         else:
-            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
+            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator)
 
         data_train = self.pad_items(data_train, self.num_class_images)
 
@@ -324,7 +320,7 @@ class VlpnDataModule():
             data_train, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            batch_size=self.batch_size, fill_batch=True, generator=generator,
+            batch_size=self.batch_size, fill_batch=True, generator=self.generator,
             size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
         )
@@ -344,7 +340,7 @@ class VlpnDataModule():
             data_val, self.tokenizer,
             num_buckets=self.num_buckets, progressive_buckets=True,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            batch_size=self.batch_size, generator=generator,
+            batch_size=self.batch_size, generator=self.generator,
             size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
         )
 
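The data/csv.py change above replaces the module-owned seed with an externally supplied torch.Generator that random_split and the dataset constructors consume directly. Because a shared generator advances its state on every use, calling setup() again in a later training cycle yields a different split, whereas the old code re-seeded a fresh generator inside setup() and reproduced the same split every time. A minimal standalone sketch of that behaviour (toy data, not the repo's classes):

import torch
from torch.utils.data import random_split

items = list(range(10))

# Shared generator, as VlpnDataModule now receives via `generator=`:
# every random_split call advances its state, so successive calls
# see different train/valid splits (still reproducible from one seed).
shared = torch.Generator(device="cpu").manual_seed(42)
split_a, _ = random_split(items, [8, 2], generator=shared)
split_b, _ = random_split(items, [8, 2], generator=shared)
print(list(split_a.indices) == list(split_b.indices))  # usually False

# Old behaviour: a generator freshly seeded on each setup() yields the
# identical split on every call.
split_c, _ = random_split(items, [8, 2], generator=torch.Generator().manual_seed(42))
split_d, _ = random_split(items, [8, 2], generator=torch.Generator().manual_seed(42))
print(list(split_c.indices) == list(split_d.indices))  # True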
diff --git a/train_lora.py b/train_lora.py
index 4bbc64e..0d8ee23 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -865,6 +865,8 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )
 
+    data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+
     create_datamodule = partial(
         VlpnDataModule,
         data_file=args.train_data_file,
@@ -882,8 +884,8 @@ def main():
         valid_set_size=args.valid_set_size,
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
-        seed=args.seed,
         dtype=weight_dtype,
+        generator=data_generator,
     )
 
     create_lr_scheduler = partial(
diff --git a/train_ti.py b/train_ti.py
index eb08bda..009495b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -817,6 +817,8 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
+
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -855,9 +857,9 @@ def main():
         valid_set_size=args.valid_set_size,
         train_set_pad=args.train_set_pad,
         valid_set_pad=args.valid_set_pad,
-        seed=args.seed,
         filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
-        dtype=weight_dtype
+        dtype=weight_dtype,
+        generator=data_generator,
     )
     datamodule.setup()
 
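For context, a hedged sketch of the cycle-level effect the train_lora.py and train_ti.py wiring aims at: one generator is seeded once from args.seed and then handed to every cycle's datamodule, so each cycle re-splits and re-shuffles the data while the run as a whole stays reproducible. The loop and the setup_cycle stand-in below are illustrative only, not the scripts' actual structure:

import torch
from torch.utils.data import random_split

def setup_cycle(items, generator):
    # Toy stand-in for VlpnDataModule.setup(): split with the *shared* generator.
    train, valid = random_split(items, [8, 2], generator=generator)
    return sorted(train.indices), sorted(valid.indices)

seed = 1  # plays the role of args.seed
data_generator = torch.Generator(device="cpu").manual_seed(seed)

for cycle in range(3):
    # Same generator object every cycle -> a different split per cycle.
    train_idx, valid_idx = setup_cycle(list(range(10)), data_generator)
    print(f"cycle {cycle}: train={train_idx} valid={valid_idx}")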