From 6c072fe50b3bfc561f22e5d591212d30de3c2dd2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 3 Oct 2022 12:08:16 +0200 Subject: Fixed euler_a generator argument --- dreambooth.py | 8 -------- 1 file changed, 8 deletions(-) (limited to 'dreambooth.py') diff --git a/dreambooth.py b/dreambooth.py index 88cd0da..75602dc 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -414,8 +414,6 @@ class Checkpointer: prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] - generator = torch.Generator(device="cuda").manual_seed(self.seed + i) - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, @@ -425,13 +423,11 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, - generator=generator, output_type='pil' )["sample"] all_samples += samples - del generator del samples image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) @@ -452,8 +448,6 @@ class Checkpointer: prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] - generator = torch.Generator(device="cuda").manual_seed(self.seed + i) - with self.accelerator.autocast(): samples = pipeline( prompt=prompt, @@ -462,13 +456,11 @@ class Checkpointer: guidance_scale=guidance_scale, eta=eta, num_inference_steps=num_inference_steps, - generator=generator, output_type='pil' )["sample"] all_samples += samples - del generator del samples image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) -- cgit v1.2.3-54-g00ecf