summary refs log tree commit diff stats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-03 12:08:16 +0200
committerVolpeon <git@volpeon.ink>2022-10-03 12:08:16 +0200
commit6c072fe50b3bfc561f22e5d591212d30de3c2dd2 (patch)
treee6dd60b5fa696d614ccc1cddb869c12c29f6ab46 /textual_inversion.py
parentAssign unused images in validation dataset to train dataset (diff)
downloadtextual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.tar.gz
textual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.tar.bz2
textual-inversion-diff-6c072fe50b3bfc561f22e5d591212d30de3c2dd2.zip
Fixed euler_a generator argument
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py8
1 file changed, 0 insertions, 8 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index fa6214e..285aa0a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -403,8 +403,6 @@ class Checkpointer:
403 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 403 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
404 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] 404 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
405 405
406 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
407
408 with self.accelerator.autocast(): 406 with self.accelerator.autocast():
409 samples = pipeline( 407 samples = pipeline(
410 prompt=prompt, 408 prompt=prompt,
@@ -414,13 +412,11 @@ class Checkpointer:
414 guidance_scale=guidance_scale, 412 guidance_scale=guidance_scale,
415 eta=eta, 413 eta=eta,
416 num_inference_steps=num_inference_steps, 414 num_inference_steps=num_inference_steps,
417 generator=generator,
418 output_type='pil' 415 output_type='pil'
419 )["sample"] 416 )["sample"]
420 417
421 all_samples += samples 418 all_samples += samples
422 419
423 del generator
424 del samples 420 del samples
425 421
426 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) 422 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -441,8 +437,6 @@ class Checkpointer:
441 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 437 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
442 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] 438 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
443 439
444 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
445
446 with self.accelerator.autocast(): 440 with self.accelerator.autocast():
447 samples = pipeline( 441 samples = pipeline(
448 prompt=prompt, 442 prompt=prompt,
@@ -451,13 +445,11 @@ class Checkpointer:
451 guidance_scale=guidance_scale, 445 guidance_scale=guidance_scale,
452 eta=eta, 446 eta=eta,
453 num_inference_steps=num_inference_steps, 447 num_inference_steps=num_inference_steps,
454 generator=generator,
455 output_type='pil' 448 output_type='pil'
456 )["sample"] 449 )["sample"]
457 450
458 all_samples += samples 451 all_samples += samples
459 452
460 del generator
461 del samples 453 del samples
462 454
463 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 455 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)