diff options
author | Volpeon <git@volpeon.ink> | 2022-10-03 12:08:16 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-03 12:08:16 +0200 |
commit | 6c072fe50b3bfc561f22e5d591212d30de3c2dd2 (patch) | |
tree | e6dd60b5fa696d614ccc1cddb869c12c29f6ab46 /dreambooth.py | |
parent | Assign unused images in validation dataset to train dataset (diff) | |
download | textual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.tar.gz textual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.tar.bz2 textual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.zip |
Fixed euler_a generator argument
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 8 |
1 files changed, 0 insertions, 8 deletions
diff --git a/dreambooth.py b/dreambooth.py index 88cd0da..75602dc 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -414,8 +414,6 @@ class Checkpointer: | |||
414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | 415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] |
416 | 416 | ||
417 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
418 | |||
419 | with self.accelerator.autocast(): | 417 | with self.accelerator.autocast(): |
420 | samples = pipeline( | 418 | samples = pipeline( |
421 | prompt=prompt, | 419 | prompt=prompt, |
@@ -425,13 +423,11 @@ class Checkpointer: | |||
425 | guidance_scale=guidance_scale, | 423 | guidance_scale=guidance_scale, |
426 | eta=eta, | 424 | eta=eta, |
427 | num_inference_steps=num_inference_steps, | 425 | num_inference_steps=num_inference_steps, |
428 | generator=generator, | ||
429 | output_type='pil' | 426 | output_type='pil' |
430 | )["sample"] | 427 | )["sample"] |
431 | 428 | ||
432 | all_samples += samples | 429 | all_samples += samples |
433 | 430 | ||
434 | del generator | ||
435 | del samples | 431 | del samples |
436 | 432 | ||
437 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 433 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
@@ -452,8 +448,6 @@ class Checkpointer: | |||
452 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 448 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
453 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 449 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] |
454 | 450 | ||
455 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
456 | |||
457 | with self.accelerator.autocast(): | 451 | with self.accelerator.autocast(): |
458 | samples = pipeline( | 452 | samples = pipeline( |
459 | prompt=prompt, | 453 | prompt=prompt, |
@@ -462,13 +456,11 @@ class Checkpointer: | |||
462 | guidance_scale=guidance_scale, | 456 | guidance_scale=guidance_scale, |
463 | eta=eta, | 457 | eta=eta, |
464 | num_inference_steps=num_inference_steps, | 458 | num_inference_steps=num_inference_steps, |
465 | generator=generator, | ||
466 | output_type='pil' | 459 | output_type='pil' |
467 | )["sample"] | 460 | )["sample"] |
468 | 461 | ||
469 | all_samples += samples | 462 | all_samples += samples |
470 | 463 | ||
471 | del generator | ||
472 | del samples | 464 | del samples |
473 | 465 | ||
474 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 466 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |