diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 8 |
1 files changed, 0 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py index 88cd0da..75602dc 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -414,8 +414,6 @@ class Checkpointer: | |||
| 414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | 415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] |
| 416 | 416 | ||
| 417 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 418 | |||
| 419 | with self.accelerator.autocast(): | 417 | with self.accelerator.autocast(): |
| 420 | samples = pipeline( | 418 | samples = pipeline( |
| 421 | prompt=prompt, | 419 | prompt=prompt, |
| @@ -425,13 +423,11 @@ class Checkpointer: | |||
| 425 | guidance_scale=guidance_scale, | 423 | guidance_scale=guidance_scale, |
| 426 | eta=eta, | 424 | eta=eta, |
| 427 | num_inference_steps=num_inference_steps, | 425 | num_inference_steps=num_inference_steps, |
| 428 | generator=generator, | ||
| 429 | output_type='pil' | 426 | output_type='pil' |
| 430 | )["sample"] | 427 | )["sample"] |
| 431 | 428 | ||
| 432 | all_samples += samples | 429 | all_samples += samples |
| 433 | 430 | ||
| 434 | del generator | ||
| 435 | del samples | 431 | del samples |
| 436 | 432 | ||
| 437 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 433 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
| @@ -452,8 +448,6 @@ class Checkpointer: | |||
| 452 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 448 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 453 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 449 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] |
| 454 | 450 | ||
| 455 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 456 | |||
| 457 | with self.accelerator.autocast(): | 451 | with self.accelerator.autocast(): |
| 458 | samples = pipeline( | 452 | samples = pipeline( |
| 459 | prompt=prompt, | 453 | prompt=prompt, |
| @@ -462,13 +456,11 @@ class Checkpointer: | |||
| 462 | guidance_scale=guidance_scale, | 456 | guidance_scale=guidance_scale, |
| 463 | eta=eta, | 457 | eta=eta, |
| 464 | num_inference_steps=num_inference_steps, | 458 | num_inference_steps=num_inference_steps, |
| 465 | generator=generator, | ||
| 466 | output_type='pil' | 459 | output_type='pil' |
| 467 | )["sample"] | 460 | )["sample"] |
| 468 | 461 | ||
| 469 | all_samples += samples | 462 | all_samples += samples |
| 470 | 463 | ||
| 471 | del generator | ||
| 472 | del samples | 464 | del samples |
| 473 | 465 | ||
| 474 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 466 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
