summary | refs | log | tree | commit | diff | stats
path: root/dreambooth.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2022-10-03 11:26:31 +0200
committer: Volpeon <git@volpeon.ink> 2022-10-03 11:26:31 +0200
commit: 0f493e1ac8406de061861ed390f283e821180e79 (patch)
tree: 0186a40130f095f1a3bdaa3bf4064a5bd5d35187 /dreambooth.py
parent: Small performance improvements (diff)
download: textual-inversion-diff-0f493e1ac8406de061861ed390f283e821180e79.tar.gz
textual-inversion-diff-0f493e1ac8406de061861ed390f283e821180e79.tar.bz2
textual-inversion-diff-0f493e1ac8406de061861ed390f283e821180e79.zip
Use euler_a for samples in learning scripts; backported improvement from Dreambooth to Textual Inversion
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- dreambooth.py 26
1 file changed, 20 insertions, 6 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 4d7366c..744d1bc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -14,12 +14,14 @@ from accelerate import Accelerator
14from accelerate.logging import get_logger 14from accelerate.logging import get_logger
15from accelerate.utils import LoggerType, set_seed 15from accelerate.utils import LoggerType, set_seed
16from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel 16from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
17from schedulers.scheduling_euler_a import EulerAScheduler
17from diffusers.optimization import get_scheduler 18from diffusers.optimization import get_scheduler
18from pipelines.stable_diffusion.no_check import NoCheck 19from pipelines.stable_diffusion.no_check import NoCheck
19from PIL import Image 20from PIL import Image
20from tqdm.auto import tqdm 21from tqdm.auto import tqdm
21from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 22from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
22from slugify import slugify 23from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
23import json 25import json
24 26
25from data.dreambooth.csv import CSVDataModule 27from data.dreambooth.csv import CSVDataModule
@@ -215,7 +217,7 @@ def parse_args():
215 parser.add_argument( 217 parser.add_argument(
216 "--sample_steps", 218 "--sample_steps",
217 type=int, 219 type=int,
218 default=80, 220 default=30,
219 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 221 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
220 ) 222 )
221 parser.add_argument( 223 parser.add_argument(
@@ -377,15 +379,16 @@ class Checkpointer:
377 samples_path = Path(self.output_dir).joinpath("samples") 379 samples_path = Path(self.output_dir).joinpath("samples")
378 380
379 unwrapped = self.accelerator.unwrap_model(self.unet) 381 unwrapped = self.accelerator.unwrap_model(self.unet)
380 pipeline = StableDiffusionPipeline( 382 scheduler = EulerAScheduler(
383 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
384 )
385
386 pipeline = VlpnStableDiffusion(
381 text_encoder=self.text_encoder, 387 text_encoder=self.text_encoder,
382 vae=self.vae, 388 vae=self.vae,
383 unet=unwrapped, 389 unet=unwrapped,
384 tokenizer=self.tokenizer, 390 tokenizer=self.tokenizer,
385 scheduler=LMSDiscreteScheduler( 391 scheduler=scheduler,
386 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
387 ),
388 safety_checker=NoCheck(),
389 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 392 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
390 ).to(self.accelerator.device) 393 ).to(self.accelerator.device)
391 pipeline.enable_attention_slicing() 394 pipeline.enable_attention_slicing()
@@ -411,6 +414,8 @@ class Checkpointer:
411 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 414 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
412 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] 415 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
413 416
417 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
418
414 with self.accelerator.autocast(): 419 with self.accelerator.autocast():
415 samples = pipeline( 420 samples = pipeline(
416 prompt=prompt, 421 prompt=prompt,
@@ -420,10 +425,13 @@ class Checkpointer:
420 guidance_scale=guidance_scale, 425 guidance_scale=guidance_scale,
421 eta=eta, 426 eta=eta,
422 num_inference_steps=num_inference_steps, 427 num_inference_steps=num_inference_steps,
428 generator=generator,
423 output_type='pil' 429 output_type='pil'
424 )["sample"] 430 )["sample"]
425 431
426 all_samples += samples 432 all_samples += samples
433
434 del generator
427 del samples 435 del samples
428 436
429 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) 437 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -444,6 +452,8 @@ class Checkpointer:
444 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 452 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
445 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] 453 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
446 454
455 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
456
447 with self.accelerator.autocast(): 457 with self.accelerator.autocast():
448 samples = pipeline( 458 samples = pipeline(
449 prompt=prompt, 459 prompt=prompt,
@@ -452,10 +462,13 @@ class Checkpointer:
452 guidance_scale=guidance_scale, 462 guidance_scale=guidance_scale,
453 eta=eta, 463 eta=eta,
454 num_inference_steps=num_inference_steps, 464 num_inference_steps=num_inference_steps,
465 generator=generator,
455 output_type='pil' 466 output_type='pil'
456 )["sample"] 467 )["sample"]
457 468
458 all_samples += samples 469 all_samples += samples
470
471 del generator
459 del samples 472 del samples
460 473
461 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 474 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
@@ -465,6 +478,7 @@ class Checkpointer:
465 del image_grid 478 del image_grid
466 479
467 del unwrapped 480 del unwrapped
481 del scheduler
468 del pipeline 482 del pipeline
469 483
470 if torch.cuda.is_available(): 484 if torch.cuda.is_available():