path: root/dreambooth.py
author     Volpeon <git@volpeon.ink>    2022-10-03 12:08:16 +0200
committer  Volpeon <git@volpeon.ink>    2022-10-03 12:08:16 +0200
commit     6c072fe50b3bfc561f22e5d591212d30de3c2dd2 (patch)
tree       e6dd60b5fa696d614ccc1cddb869c12c29f6ab46 /dreambooth.py
parent     Assign unused images in validation dataset to train dataset (diff)
Fixed euler_a generator argument
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  8  --------
1 file changed, 0 insertions(+), 8 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 88cd0da..75602dc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -414,8 +414,6 @@ class Checkpointer:
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                     batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
 
-                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
-
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
@@ -425,13 +423,11 @@
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
-                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                 all_samples += samples
 
-                del generator
                 del samples
 
                 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -452,8 +448,6 @@ class Checkpointer:
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                     batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
 
-                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
-
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
@@ -462,13 +456,11 @@
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
-                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                 all_samples += samples
 
-                del generator
                 del samples
 
                 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
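
For reference, a minimal sketch (not part of the commit) of how the sampling call reads after these deletions. The names used here (pipeline, prompt, guidance_scale, eta, num_inference_steps, all_samples, self.accelerator) all come from the hunk context above; the enclosing batch loop over i and the rest of the Checkpointer method are assumed rather than shown in this diff.

    # Sketch only: the per-batch torch.Generator is gone, so sampling presumably
    # falls back to the default RNG instead of a generator seeded with self.seed + i.
    with self.accelerator.autocast():
        samples = pipeline(
            prompt=prompt,
            guidance_scale=guidance_scale,
            eta=eta,
            num_inference_steps=num_inference_steps,
            output_type='pil'
        )["sample"]

    all_samples += samples
    del samples

The same pair of deletions is applied to both sampling loops in Checkpointer: the stable-sample pass (val_data.batch_size, self.stable_sample_batches) and the random-sample pass (data.batch_size, self.random_sample_batches), so neither path seeds its samples explicitly any more.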