diff options
| -rw-r--r-- | data/dreambooth/csv.py | 6 | ||||
| -rw-r--r-- | data/textual_inversion/csv.py | 8 | ||||
| -rw-r--r-- | dreambooth.py | 8 | ||||
| -rw-r--r-- | schedulers/scheduling_euler_a.py | 3 | ||||
| -rw-r--r-- | textual_inversion.py | 8 |
5 files changed, 10 insertions, 23 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 4087226..08ed49c 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -22,6 +22,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 22 | identifier="*", | 22 | identifier="*", |
| 23 | center_crop=False, | 23 | center_crop=False, |
| 24 | valid_set_size=None, | 24 | valid_set_size=None, |
| 25 | generator=None, | ||
| 25 | collate_fn=None): | 26 | collate_fn=None): |
| 26 | super().__init__() | 27 | super().__init__() |
| 27 | 28 | ||
| @@ -41,6 +42,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 41 | self.center_crop = center_crop | 42 | self.center_crop = center_crop |
| 42 | self.interpolation = interpolation | 43 | self.interpolation = interpolation |
| 43 | self.valid_set_size = valid_set_size | 44 | self.valid_set_size = valid_set_size |
| 45 | self.generator = generator | ||
| 44 | self.collate_fn = collate_fn | 46 | self.collate_fn = collate_fn |
| 45 | self.batch_size = batch_size | 47 | self.batch_size = batch_size |
| 46 | 48 | ||
| @@ -54,10 +56,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 54 | def setup(self, stage=None): | 56 | def setup(self, stage=None): |
| 55 | valid_set_size = int(len(self.data_full) * 0.2) | 57 | valid_set_size = int(len(self.data_full) * 0.2) |
| 56 | if self.valid_set_size: | 58 | if self.valid_set_size: |
| 57 | valid_set_size = math.min(valid_set_size, self.valid_set_size) | 59 | valid_set_size = min(valid_set_size, self.valid_set_size) |
| 58 | train_set_size = len(self.data_full) - valid_set_size | 60 | train_set_size = len(self.data_full) - valid_set_size |
| 59 | 61 | ||
| 60 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 62 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) |
| 61 | 63 | ||
| 62 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 64 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, |
| 63 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | 65 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, |
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index e082511..3ac57df 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
| @@ -19,7 +19,8 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 19 | interpolation="bicubic", | 19 | interpolation="bicubic", |
| 20 | placeholder_token="*", | 20 | placeholder_token="*", |
| 21 | center_crop=False, | 21 | center_crop=False, |
| 22 | valid_set_size=None): | 22 | valid_set_size=None, |
| 23 | generator=None): | ||
| 23 | super().__init__() | 24 | super().__init__() |
| 24 | 25 | ||
| 25 | self.data_file = Path(data_file) | 26 | self.data_file = Path(data_file) |
| @@ -35,6 +36,7 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 35 | self.center_crop = center_crop | 36 | self.center_crop = center_crop |
| 36 | self.interpolation = interpolation | 37 | self.interpolation = interpolation |
| 37 | self.valid_set_size = valid_set_size | 38 | self.valid_set_size = valid_set_size |
| 39 | self.generator = generator | ||
| 38 | 40 | ||
| 39 | self.batch_size = batch_size | 41 | self.batch_size = batch_size |
| 40 | 42 | ||
| @@ -48,10 +50,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 48 | def setup(self, stage=None): | 50 | def setup(self, stage=None): |
| 49 | valid_set_size = int(len(self.data_full) * 0.2) | 51 | valid_set_size = int(len(self.data_full) * 0.2) |
| 50 | if self.valid_set_size: | 52 | if self.valid_set_size: |
| 51 | valid_set_size = math.min(valid_set_size, self.valid_set_size) | 53 | valid_set_size = min(valid_set_size, self.valid_set_size) |
| 52 | train_set_size = len(self.data_full) - valid_set_size | 54 | train_set_size = len(self.data_full) - valid_set_size |
| 53 | 55 | ||
| 54 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 56 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) |
| 55 | 57 | ||
| 56 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, | 58 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, |
| 57 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 59 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
diff --git a/dreambooth.py b/dreambooth.py index 88cd0da..75602dc 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -414,8 +414,6 @@ class Checkpointer: | |||
| 414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | 415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] |
| 416 | 416 | ||
| 417 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 418 | |||
| 419 | with self.accelerator.autocast(): | 417 | with self.accelerator.autocast(): |
| 420 | samples = pipeline( | 418 | samples = pipeline( |
| 421 | prompt=prompt, | 419 | prompt=prompt, |
| @@ -425,13 +423,11 @@ class Checkpointer: | |||
| 425 | guidance_scale=guidance_scale, | 423 | guidance_scale=guidance_scale, |
| 426 | eta=eta, | 424 | eta=eta, |
| 427 | num_inference_steps=num_inference_steps, | 425 | num_inference_steps=num_inference_steps, |
| 428 | generator=generator, | ||
| 429 | output_type='pil' | 426 | output_type='pil' |
| 430 | )["sample"] | 427 | )["sample"] |
| 431 | 428 | ||
| 432 | all_samples += samples | 429 | all_samples += samples |
| 433 | 430 | ||
| 434 | del generator | ||
| 435 | del samples | 431 | del samples |
| 436 | 432 | ||
| 437 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 433 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
| @@ -452,8 +448,6 @@ class Checkpointer: | |||
| 452 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 448 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 453 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 449 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] |
| 454 | 450 | ||
| 455 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 456 | |||
| 457 | with self.accelerator.autocast(): | 451 | with self.accelerator.autocast(): |
| 458 | samples = pipeline( | 452 | samples = pipeline( |
| 459 | prompt=prompt, | 453 | prompt=prompt, |
| @@ -462,13 +456,11 @@ class Checkpointer: | |||
| 462 | guidance_scale=guidance_scale, | 456 | guidance_scale=guidance_scale, |
| 463 | eta=eta, | 457 | eta=eta, |
| 464 | num_inference_steps=num_inference_steps, | 458 | num_inference_steps=num_inference_steps, |
| 465 | generator=generator, | ||
| 466 | output_type='pil' | 459 | output_type='pil' |
| 467 | )["sample"] | 460 | )["sample"] |
| 468 | 461 | ||
| 469 | all_samples += samples | 462 | all_samples += samples |
| 470 | 463 | ||
| 471 | del generator | ||
| 472 | del samples | 464 | del samples |
| 473 | 465 | ||
| 474 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 466 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index d7fea85..c6436d8 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py | |||
| @@ -198,7 +198,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 198 | timestep: int, | 198 | timestep: int, |
| 199 | timestep_prev: int, | 199 | timestep_prev: int, |
| 200 | sample: torch.FloatTensor, | 200 | sample: torch.FloatTensor, |
| 201 | generator: None, | 201 | generator: torch.Generator = None, |
| 202 | return_dict: bool = True, | 202 | return_dict: bool = True, |
| 203 | ) -> Union[SchedulerOutput, Tuple]: | 203 | ) -> Union[SchedulerOutput, Tuple]: |
| 204 | """ | 204 | """ |
| @@ -240,7 +240,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): | |||
| 240 | sample_hat: torch.FloatTensor, | 240 | sample_hat: torch.FloatTensor, |
| 241 | sample_prev: torch.FloatTensor, | 241 | sample_prev: torch.FloatTensor, |
| 242 | derivative: torch.FloatTensor, | 242 | derivative: torch.FloatTensor, |
| 243 | generator: None, | ||
| 244 | return_dict: bool = True, | 243 | return_dict: bool = True, |
| 245 | ) -> Union[SchedulerOutput, Tuple]: | 244 | ) -> Union[SchedulerOutput, Tuple]: |
| 246 | """ | 245 | """ |
diff --git a/textual_inversion.py b/textual_inversion.py index fa6214e..285aa0a 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -403,8 +403,6 @@ class Checkpointer: | |||
| 403 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 403 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 404 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | 404 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] |
| 405 | 405 | ||
| 406 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 407 | |||
| 408 | with self.accelerator.autocast(): | 406 | with self.accelerator.autocast(): |
| 409 | samples = pipeline( | 407 | samples = pipeline( |
| 410 | prompt=prompt, | 408 | prompt=prompt, |
| @@ -414,13 +412,11 @@ class Checkpointer: | |||
| 414 | guidance_scale=guidance_scale, | 412 | guidance_scale=guidance_scale, |
| 415 | eta=eta, | 413 | eta=eta, |
| 416 | num_inference_steps=num_inference_steps, | 414 | num_inference_steps=num_inference_steps, |
| 417 | generator=generator, | ||
| 418 | output_type='pil' | 415 | output_type='pil' |
| 419 | )["sample"] | 416 | )["sample"] |
| 420 | 417 | ||
| 421 | all_samples += samples | 418 | all_samples += samples |
| 422 | 419 | ||
| 423 | del generator | ||
| 424 | del samples | 420 | del samples |
| 425 | 421 | ||
| 426 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 422 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
| @@ -441,8 +437,6 @@ class Checkpointer: | |||
| 441 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 437 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 442 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 438 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] |
| 443 | 439 | ||
| 444 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 445 | |||
| 446 | with self.accelerator.autocast(): | 440 | with self.accelerator.autocast(): |
| 447 | samples = pipeline( | 441 | samples = pipeline( |
| 448 | prompt=prompt, | 442 | prompt=prompt, |
| @@ -451,13 +445,11 @@ class Checkpointer: | |||
| 451 | guidance_scale=guidance_scale, | 445 | guidance_scale=guidance_scale, |
| 452 | eta=eta, | 446 | eta=eta, |
| 453 | num_inference_steps=num_inference_steps, | 447 | num_inference_steps=num_inference_steps, |
| 454 | generator=generator, | ||
| 455 | output_type='pil' | 448 | output_type='pil' |
| 456 | )["sample"] | 449 | )["sample"] |
| 457 | 450 | ||
| 458 | all_samples += samples | 451 | all_samples += samples |
| 459 | 452 | ||
| 460 | del generator | ||
| 461 | del samples | 453 | del samples |
| 462 | 454 | ||
| 463 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 455 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
