author     Volpeon <git@volpeon.ink>  2022-10-03 11:26:31 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-03 11:26:31 +0200
commit     0f493e1ac8406de061861ed390f283e821180e79
tree       0186a40130f095f1a3bdaa3bf4064a5bd5d35187 /dreambooth.py
parent     Small performance improvements
Use euler_a for samples in learning scripts; backported improvement from Dreambooth to Textual Inversion
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  26
1 file changed, 20 insertions(+), 6 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 4d7366c..744d1bc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -14,12 +14,14 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 
 from data.dreambooth.csv import CSVDataModule
@@ -215,7 +217,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=80,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -377,15 +379,16 @@ class Checkpointer:
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        scheduler = EulerAScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=unwrapped,
             tokenizer=self.tokenizer,
-            scheduler=LMSDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            ),
-            safety_checker=NoCheck(),
+            scheduler=scheduler,
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
@@ -411,6 +414,8 @@ class Checkpointer:
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                     batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
 
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
+
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
@@ -420,10 +425,13 @@ class Checkpointer:
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
+                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                 all_samples += samples
+
+                del generator
                 del samples
 
             image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -444,6 +452,8 @@ class Checkpointer:
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                     batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
 
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
+
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
@@ -452,10 +462,13 @@ class Checkpointer:
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
+                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                 all_samples += samples
+
+                del generator
                 del samples
 
             image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
@@ -465,6 +478,7 @@ class Checkpointer:
             del image_grid
 
         del unwrapped
+        del scheduler
         del pipeline
 
         if torch.cuda.is_available():
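For context, a minimal sketch of the sampling path this commit switches to: the repo-local EulerAScheduler is paired with the VlpnStableDiffusion pipeline, and each sample batch gets its own seeded torch.Generator. The constructor and call arguments mirror what is visible in the diff; the checkpoint path, prompt, guidance/eta values, and the component loading are illustrative assumptions, not part of the commit (in dreambooth.py these components already live on the Checkpointer).

# Sketch only: how the pieces added in this commit fit together at sampling time.
# EulerAScheduler and VlpnStableDiffusion are the repo-local modules added to the imports;
# the checkpoint path and prompt below are placeholder assumptions.
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from schedulers.scheduling_euler_a import EulerAScheduler
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

model_path = "CompVis/stable-diffusion-v1-4"  # placeholder; not part of this commit

# The script already holds these as self.text_encoder, self.vae, self.tokenizer, and the unwrapped UNet.
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")

# Same betas that were previously passed to LMSDiscreteScheduler, now on the Euler ancestral scheduler.
scheduler = EulerAScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)

pipeline = VlpnStableDiffusion(
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    scheduler=scheduler,
    feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to("cuda")
pipeline.enable_attention_slicing()

# One generator per batch, seeded with seed + batch index in the script, so sample
# grids stay comparable across checkpoints.
generator = torch.Generator(device="cuda").manual_seed(42)

samples = pipeline(
    prompt=["a photo of a sks dog"],  # placeholder prompt
    guidance_scale=7.5,               # placeholder value
    eta=0.0,                          # placeholder value
    num_inference_steps=30,           # matches the new --sample_steps default
    generator=generator,
    output_type="pil",
)["sample"]
samples[0].save("sample.png")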