| | |
|---|---|
| author | Volpeon <git@volpeon.ink> (2022-10-03 11:26:31 +0200) |
| committer | Volpeon <git@volpeon.ink> (2022-10-03 11:26:31 +0200) |
| commit | 0f493e1ac8406de061861ed390f283e821180e79 |
| tree | 0186a40130f095f1a3bdaa3bf4064a5bd5d35187 /dreambooth.py |
| parent | Small performance improvements |
Use euler_a for samples in learning scripts; backported improvement from Dreambooth to Textual Inversion
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- dreambooth.py | 26 ++++++++++++++++++++------

1 file changed, 20 insertions(+), 6 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 4d7366c..744d1bc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -14,12 +14,14 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 
 from data.dreambooth.csv import CSVDataModule
@@ -215,7 +217,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=80,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -377,15 +379,16 @@ class Checkpointer:
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        scheduler = EulerAScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=unwrapped,
             tokenizer=self.tokenizer,
-            scheduler=LMSDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            ),
-            safety_checker=NoCheck(),
+            scheduler=scheduler,
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
@@ -411,6 +414,8 @@ class Checkpointer:
             prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
 
+            generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
+
             with self.accelerator.autocast():
                 samples = pipeline(
                     prompt=prompt,
@@ -420,10 +425,13 @@
                     guidance_scale=guidance_scale,
                     eta=eta,
                     num_inference_steps=num_inference_steps,
+                    generator=generator,
                     output_type='pil'
                 )["sample"]
 
                 all_samples += samples
+
+                del generator
                 del samples
 
         image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -444,6 +452,8 @@ class Checkpointer:
             prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
 
+            generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
+
             with self.accelerator.autocast():
                 samples = pipeline(
                     prompt=prompt,
@@ -452,10 +462,13 @@
                     guidance_scale=guidance_scale,
                     eta=eta,
                     num_inference_steps=num_inference_steps,
+                    generator=generator,
                     output_type='pil'
                 )["sample"]
 
                 all_samples += samples
+
+                del generator
                 del samples
 
         image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
@@ -465,6 +478,7 @@ class Checkpointer:
         del image_grid
 
         del unwrapped
+        del scheduler
         del pipeline
 
         if torch.cuda.is_available():
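The change is easiest to read as a pattern: construct the ancestral scheduler once, hand it to the sampling pipeline in place of the previous LMSDiscreteScheduler, and seed a fresh `torch.Generator` for each sample batch so the checkpoint grids stay reproducible. Below is a minimal sketch of that pattern in isolation, assuming the repo's `EulerAScheduler` and `VlpnStableDiffusion` behave as the diff uses them; the function name `render_sample_batches`, the prompt batching, and any constructor or call arguments not shown in the diff are illustrative assumptions, not the script's actual API.

```python
# Illustrative sketch only; EulerAScheduler and VlpnStableDiffusion are the
# repo-local modules imported in the diff above, everything else is assumed.
import torch
from transformers import CLIPFeatureExtractor

from schedulers.scheduling_euler_a import EulerAScheduler
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion


def render_sample_batches(text_encoder, vae, unet, tokenizer, prompts, seed,
                          batch_size=4, steps=30, guidance_scale=7.5):
    """Condensed version of the sampling path this commit sets up."""
    # One scheduler instance shared by every batch; betas as in the diff.
    scheduler = EulerAScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        feature_extractor=CLIPFeatureExtractor.from_pretrained(
            "openai/clip-vit-base-patch32"
        ),
    ).to("cuda")
    pipeline.enable_attention_slicing()

    batches = [prompts[i:i + batch_size] for i in range(0, len(prompts), batch_size)]

    all_samples = []
    for i, batch in enumerate(batches):
        # Fresh generator per batch, seeded from the run seed plus the batch
        # index, so the sample grid is reproducible across checkpoints.
        generator = torch.Generator(device="cuda").manual_seed(seed + i)

        samples = pipeline(
            prompt=batch,
            num_inference_steps=steps,   # the new --sample_steps default is 30
            guidance_scale=guidance_scale,
            generator=generator,
            output_type="pil",
        )["sample"]

        all_samples += samples
        del generator, samples

    return all_samples
```

Lowering the `--sample_steps` default from 80 to 30 goes hand in hand with the scheduler swap: Euler-ancestral sampling generally produces a usable preview in far fewer steps than the earlier LMS setup, so the per-checkpoint sample pass gets noticeably cheaper.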
