author    Volpeon <git@volpeon.ink>  2022-10-03 17:38:44 +0200
committer Volpeon <git@volpeon.ink>  2022-10-03 17:38:44 +0200
commit    f23fd5184b8ba4ec04506495f4a61726e50756f7 (patch)
tree      d4c5666b291316ed95437cc1c917b03ef3b679da /dreambooth.py
parent    Added negative prompt support for training scripts (diff)
Small perf improvements
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py | 89
1 file changed, 46 insertions(+), 43 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 5fbf172..9d6b8d6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
@@ -30,6 +30,9 @@ from data.dreambooth.prompt import PromptDataset
 logger = get_logger(__name__)
 
 
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
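
Note: the one functional addition here, torch.backends.cuda.matmul.allow_tf32 = True, opts float32 matmuls into TensorFloat-32 tensor cores on Ampere-and-newer GPUs, trading a few mantissa bits for a large matmul speedup — one of the "small perf improvements" the commit message names. A minimal sketch of the toggle in isolation (the cuDNN flag is shown for context only and is an assumption, not part of this commit):

    import torch

    # Let float32 matmuls run on TF32 tensor cores (Ampere+): float32 range,
    # ~10-bit mantissa, noticeably faster large matrix multiplies.
    torch.backends.cuda.matmul.allow_tf32 = True

    # Companion flag for cuDNN convolutions (assumption: untouched by this
    # commit; it already defaults to True in current PyTorch builds).
    torch.backends.cudnn.allow_tf32 = True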
@@ -346,7 +349,7 @@ class Checkpointer:
         print("Saving model...")
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=self.accelerator.unwrap_model(self.unet),
@@ -354,8 +357,6 @@ class Checkpointer:
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
             ),
-            safety_checker=NoCheck(),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
         pipeline.enable_attention_slicing()
         pipeline.save_pretrained(f"{self.output_dir}/model")
@@ -381,7 +382,6 @@ class Checkpointer:
             unet=unwrapped,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
 
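
Note: both Checkpointer call sites now construct the repo's own VlpnStableDiffusion in place of diffusers' StableDiffusionPipeline, dropping the safety_checker and feature_extractor arguments and, with them, a CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32") lookup on every save. Presumably the custom pipeline simply has no safety-checker slot, which also retires the NoCheck stub; a pass-through checker of that kind typically looks like the following sketch (an assumption about pipelines/stable_diffusion/no_check.py, not code from this commit):

    # Hypothetical pass-through safety checker: report every image as safe
    # and hand it back unmodified.
    class NoCheck:
        def __call__(self, images, clip_input=None, **kwargs):
            return images, [False] * len(images)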
@@ -459,44 +459,6 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        class_images_dir.mkdir(parents=True, exist_ok=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float32
-            if accelerator.device.type == "cuda":
-                torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
-
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
-            pipeline.enable_attention_slicing()
-            pipeline.set_progress_bar_config(disable=True)
-            pipeline.to(accelerator.device)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-
-            for example in tqdm(
-                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
-            ):
-                with accelerator.autocast():
-                    images = pipeline(example["prompt"]).images
-
-                for i, image in enumerate(images):
-                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     # Load the tokenizer and add the placeholder token as an additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -526,6 +488,47 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(text_encoder.parameters())
 
+    # Generate class images, if necessary
+    if args.with_prior_preservation:
+        class_images_dir = Path(args.class_data_dir)
+        class_images_dir.mkdir(parents=True, exist_ok=True)
+        cur_class_images = len(list(class_images_dir.iterdir()))
+
+        if cur_class_images < args.num_class_images:
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(disable=True)
+
+            num_new_images = args.num_class_images - cur_class_images
+            logger.info(f"Number of class images to sample: {num_new_images}.")
+
+            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+            sample_dataloader = accelerator.prepare(sample_dataloader)
+
+            for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process):
+                with accelerator.autocast():
+                    images = pipeline(example["prompt"]).images
+
+                for i, image in enumerate(images):
+                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
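
Note: the with_prior_preservation block is moved, not deleted — it now runs after the tokenizer, text encoder, VAE and UNet have been loaded, and assembles a VlpnStableDiffusion (with an EulerAScheduler) from those in-memory modules instead of paying for a second full StableDiffusionPipeline.from_pretrained() load of the same checkpoint just to sample class images. Reduced to a sketch of the reuse pattern (load_components and generate_class_images are hypothetical stand-ins for the surrounding script):

    # Modules are loaded once and shared between class-image generation and
    # training; only the lightweight pipeline wrapper is created and dropped.
    text_encoder, vae, unet, tokenizer = load_components(args)  # hypothetical helper

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer,
        scheduler=EulerAScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        ),
    ).to(accelerator.device)
    pipeline.enable_attention_slicing()

    generate_class_images(pipeline)  # hypothetical: the tqdm loop from the hunk above

    del pipeline  # drop the wrapper; the modules live on for training
    if torch.cuda.is_available():
        torch.cuda.empty_cache()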