summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/dreambooth/csv.py6
-rw-r--r--data/textual_inversion/csv.py8
-rw-r--r--dreambooth.py8
-rw-r--r--schedulers/scheduling_euler_a.py3
-rw-r--r--textual_inversion.py8
5 files changed, 10 insertions, 23 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 4087226..08ed49c 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -22,6 +22,7 @@ class CSVDataModule(pl.LightningDataModule):
22 identifier="*", 22 identifier="*",
23 center_crop=False, 23 center_crop=False,
24 valid_set_size=None, 24 valid_set_size=None,
25 generator=None,
25 collate_fn=None): 26 collate_fn=None):
26 super().__init__() 27 super().__init__()
27 28
@@ -41,6 +42,7 @@ class CSVDataModule(pl.LightningDataModule):
41 self.center_crop = center_crop 42 self.center_crop = center_crop
42 self.interpolation = interpolation 43 self.interpolation = interpolation
43 self.valid_set_size = valid_set_size 44 self.valid_set_size = valid_set_size
45 self.generator = generator
44 self.collate_fn = collate_fn 46 self.collate_fn = collate_fn
45 self.batch_size = batch_size 47 self.batch_size = batch_size
46 48
@@ -54,10 +56,10 @@ class CSVDataModule(pl.LightningDataModule):
54 def setup(self, stage=None): 56 def setup(self, stage=None):
55 valid_set_size = int(len(self.data_full) * 0.2) 57 valid_set_size = int(len(self.data_full) * 0.2)
56 if self.valid_set_size: 58 if self.valid_set_size:
57 valid_set_size = math.min(valid_set_size, self.valid_set_size) 59 valid_set_size = min(valid_set_size, self.valid_set_size)
58 train_set_size = len(self.data_full) - valid_set_size 60 train_set_size = len(self.data_full) - valid_set_size
59 61
60 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 62 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
61 63
62 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, 64 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
63 class_data_root=self.class_data_root, class_prompt=self.class_prompt, 65 class_data_root=self.class_data_root, class_prompt=self.class_prompt,
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index e082511..3ac57df 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -19,7 +19,8 @@ class CSVDataModule(pl.LightningDataModule):
19 interpolation="bicubic", 19 interpolation="bicubic",
20 placeholder_token="*", 20 placeholder_token="*",
21 center_crop=False, 21 center_crop=False,
22 valid_set_size=None): 22 valid_set_size=None,
23 generator=None):
23 super().__init__() 24 super().__init__()
24 25
25 self.data_file = Path(data_file) 26 self.data_file = Path(data_file)
@@ -35,6 +36,7 @@ class CSVDataModule(pl.LightningDataModule):
35 self.center_crop = center_crop 36 self.center_crop = center_crop
36 self.interpolation = interpolation 37 self.interpolation = interpolation
37 self.valid_set_size = valid_set_size 38 self.valid_set_size = valid_set_size
39 self.generator = generator
38 40
39 self.batch_size = batch_size 41 self.batch_size = batch_size
40 42
@@ -48,10 +50,10 @@ class CSVDataModule(pl.LightningDataModule):
48 def setup(self, stage=None): 50 def setup(self, stage=None):
49 valid_set_size = int(len(self.data_full) * 0.2) 51 valid_set_size = int(len(self.data_full) * 0.2)
50 if self.valid_set_size: 52 if self.valid_set_size:
51 valid_set_size = math.min(valid_set_size, self.valid_set_size) 53 valid_set_size = min(valid_set_size, self.valid_set_size)
52 train_set_size = len(self.data_full) - valid_set_size 54 train_set_size = len(self.data_full) - valid_set_size
53 55
54 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 56 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
55 57
56 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, 58 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
57 placeholder_token=self.placeholder_token, center_crop=self.center_crop) 59 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
diff --git a/dreambooth.py b/dreambooth.py
index 88cd0da..75602dc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -414,8 +414,6 @@ class Checkpointer:
414 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 414 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
415 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] 415 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
416 416
417 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
418
419 with self.accelerator.autocast(): 417 with self.accelerator.autocast():
420 samples = pipeline( 418 samples = pipeline(
421 prompt=prompt, 419 prompt=prompt,
@@ -425,13 +423,11 @@ class Checkpointer:
425 guidance_scale=guidance_scale, 423 guidance_scale=guidance_scale,
426 eta=eta, 424 eta=eta,
427 num_inference_steps=num_inference_steps, 425 num_inference_steps=num_inference_steps,
428 generator=generator,
429 output_type='pil' 426 output_type='pil'
430 )["sample"] 427 )["sample"]
431 428
432 all_samples += samples 429 all_samples += samples
433 430
434 del generator
435 del samples 431 del samples
436 432
437 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) 433 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -452,8 +448,6 @@ class Checkpointer:
452 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 448 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
453 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] 449 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
454 450
455 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
456
457 with self.accelerator.autocast(): 451 with self.accelerator.autocast():
458 samples = pipeline( 452 samples = pipeline(
459 prompt=prompt, 453 prompt=prompt,
@@ -462,13 +456,11 @@ class Checkpointer:
462 guidance_scale=guidance_scale, 456 guidance_scale=guidance_scale,
463 eta=eta, 457 eta=eta,
464 num_inference_steps=num_inference_steps, 458 num_inference_steps=num_inference_steps,
465 generator=generator,
466 output_type='pil' 459 output_type='pil'
467 )["sample"] 460 )["sample"]
468 461
469 all_samples += samples 462 all_samples += samples
470 463
471 del generator
472 del samples 464 del samples
473 465
474 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 466 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index d7fea85..c6436d8 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -198,7 +198,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
198 timestep: int, 198 timestep: int,
199 timestep_prev: int, 199 timestep_prev: int,
200 sample: torch.FloatTensor, 200 sample: torch.FloatTensor,
201 generator: None, 201 generator: torch.Generator = None,
202 return_dict: bool = True, 202 return_dict: bool = True,
203 ) -> Union[SchedulerOutput, Tuple]: 203 ) -> Union[SchedulerOutput, Tuple]:
204 """ 204 """
@@ -240,7 +240,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
240 sample_hat: torch.FloatTensor, 240 sample_hat: torch.FloatTensor,
241 sample_prev: torch.FloatTensor, 241 sample_prev: torch.FloatTensor,
242 derivative: torch.FloatTensor, 242 derivative: torch.FloatTensor,
243 generator: None,
244 return_dict: bool = True, 243 return_dict: bool = True,
245 ) -> Union[SchedulerOutput, Tuple]: 244 ) -> Union[SchedulerOutput, Tuple]:
246 """ 245 """
diff --git a/textual_inversion.py b/textual_inversion.py
index fa6214e..285aa0a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -403,8 +403,6 @@ class Checkpointer:
403 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 403 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
404 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] 404 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
405 405
406 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
407
408 with self.accelerator.autocast(): 406 with self.accelerator.autocast():
409 samples = pipeline( 407 samples = pipeline(
410 prompt=prompt, 408 prompt=prompt,
@@ -414,13 +412,11 @@ class Checkpointer:
414 guidance_scale=guidance_scale, 412 guidance_scale=guidance_scale,
415 eta=eta, 413 eta=eta,
416 num_inference_steps=num_inference_steps, 414 num_inference_steps=num_inference_steps,
417 generator=generator,
418 output_type='pil' 415 output_type='pil'
419 )["sample"] 416 )["sample"]
420 417
421 all_samples += samples 418 all_samples += samples
422 419
423 del generator
424 del samples 420 del samples
425 421
426 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) 422 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -441,8 +437,6 @@ class Checkpointer:
441 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 437 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
442 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] 438 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
443 439
444 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
445
446 with self.accelerator.autocast(): 440 with self.accelerator.autocast():
447 samples = pipeline( 441 samples = pipeline(
448 prompt=prompt, 442 prompt=prompt,
@@ -451,13 +445,11 @@ class Checkpointer:
451 guidance_scale=guidance_scale, 445 guidance_scale=guidance_scale,
452 eta=eta, 446 eta=eta,
453 num_inference_steps=num_inference_steps, 447 num_inference_steps=num_inference_steps,
454 generator=generator,
455 output_type='pil' 448 output_type='pil'
456 )["sample"] 449 )["sample"]
457 450
458 all_samples += samples 451 all_samples += samples
459 452
460 del generator
461 del samples 453 del samples
462 454
463 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 455 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)