From 6c072fe50b3bfc561f22e5d591212d30de3c2dd2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 12:08:16 +0200 Subject: Fixed euler_a generator argument --- data/dreambooth/csv.py | 6 ++++-- data/textual_inversion/csv.py | 8 +++++--- dreambooth.py | 8 -------- schedulers/scheduling_euler_a.py | 3 +-- textual_inversion.py | 8 -------- 5 files changed, 10 insertions(+), 23 deletions(-) diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4087226..08ed49c 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -22,6 +22,7 @@ class CSVDataModule(pl.LightningDataModule): identifier="*", center_crop=False, valid_set_size=None, + generator=None, collate_fn=None): super().__init__() @@ -41,6 +42,7 @@ class CSVDataModule(pl.LightningDataModule): self.center_crop = center_crop self.interpolation = interpolation self.valid_set_size = valid_set_size + self.generator = generator self.collate_fn = collate_fn self.batch_size = batch_size @@ -54,10 +56,10 @@ class CSVDataModule(pl.LightningDataModule): def setup(self, stage=None): valid_set_size = int(len(self.data_full) * 0.2) if self.valid_set_size: - valid_set_size = math.min(valid_set_size, self.valid_set_size) + valid_set_size = min(valid_set_size, self.valid_set_size) train_set_size = len(self.data_full) - valid_set_size - self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) + self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, class_data_root=self.class_data_root, class_prompt=self.class_prompt, diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index e082511..3ac57df 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py @@ -19,7 +19,8 @@ class CSVDataModule(pl.LightningDataModule): interpolation="bicubic", placeholder_token="*", center_crop=False, - valid_set_size=None): + valid_set_size=None, + generator=None): super().__init__() self.data_file = Path(data_file) @@ -35,6 +36,7 @@ class CSVDataModule(pl.LightningDataModule): self.center_crop = center_crop self.interpolation = interpolation self.valid_set_size = valid_set_size + self.generator = generator self.batch_size = batch_size @@ -48,10 +50,10 @@ class CSVDataModule(pl.LightningDataModule): def setup(self, stage=None): valid_set_size = int(len(self.data_full) * 0.2) if self.valid_set_size: - valid_set_size = math.min(valid_set_size, self.valid_set_size) + valid_set_size = min(valid_set_size, self.valid_set_size) train_set_size = len(self.data_full) - valid_set_size - self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) + self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, placeholder_token=self.placeholder_token, center_crop=self.center_crop) diff --git a/dreambooth.py b/dreambooth.py index 88cd0da..75602dc 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -414,8 +414,6 @@ class Checkpointer: prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] - generator = torch.Generator(device="cuda").manual_seed(self.seed + i) - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, @@ -425,13 +423,11 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, - generator=generator, output_type='pil' )["sample"] all_samples += samples - del generator del samples image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) @@ -452,8 +448,6 @@ class Checkpointer: prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] - generator = torch.Generator(device="cuda").manual_seed(self.seed + i) - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, @@ -462,13 +456,11 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, - generator=generator, output_type='pil' )["sample"] all_samples += samples - del generator del samples image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index d7fea85..c6436d8 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -198,7 +198,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): timestep: int, timestep_prev: int, sample: torch.FloatTensor, - generator: None, + generator: torch.Generator = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -240,7 +240,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): sample_hat: torch.FloatTensor, sample_prev: torch.FloatTensor, derivative: torch.FloatTensor, - generator: None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ diff --git a/textual_inversion.py b/textual_inversion.py index fa6214e..285aa0a 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -403,8 +403,6 @@ class Checkpointer: prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] - generator = torch.Generator(device="cuda").manual_seed(self.seed + i) - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, @@ -414,13 +412,11 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, - generator=generator, output_type='pil' )["sample"] all_samples += samples - del generator del samples image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) @@ -441,8 +437,6 @@ class Checkpointer: prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] - generator = torch.Generator(device="cuda").manual_seed(self.seed + i) - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, @@ -451,13 +445,11 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, - generator=generator, output_type='pil' )["sample"] all_samples += samples - del generator del samples image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) -- cgit v1.2.3-70-g09d2