diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 89 |
1 files changed, 46 insertions, 43 deletions
diff --git a/dreambooth.py b/dreambooth.py index 5fbf172..9d6b8d6 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -13,7 +13,7 @@ import torch.utils.checkpoint | |||
13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
14 | from accelerate.logging import get_logger | 14 | from accelerate.logging import get_logger |
15 | from accelerate.utils import LoggerType, set_seed | 15 | from accelerate.utils import LoggerType, set_seed |
16 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel | 16 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel |
17 | from schedulers.scheduling_euler_a import EulerAScheduler | 17 | from schedulers.scheduling_euler_a import EulerAScheduler |
18 | from diffusers.optimization import get_scheduler | 18 | from diffusers.optimization import get_scheduler |
19 | from pipelines.stable_diffusion.no_check import NoCheck | 19 | from pipelines.stable_diffusion.no_check import NoCheck |
@@ -30,6 +30,9 @@ from data.dreambooth.prompt import PromptDataset | |||
30 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
31 | 31 | ||
32 | 32 | ||
33 | torch.backends.cuda.matmul.allow_tf32 = True | ||
34 | |||
35 | |||
33 | def parse_args(): | 36 | def parse_args(): |
34 | parser = argparse.ArgumentParser( | 37 | parser = argparse.ArgumentParser( |
35 | description="Simple example of a training script." | 38 | description="Simple example of a training script." |
@@ -346,7 +349,7 @@ class Checkpointer: | |||
346 | print("Saving model...") | 349 | print("Saving model...") |
347 | 350 | ||
348 | unwrapped = self.accelerator.unwrap_model(self.unet) | 351 | unwrapped = self.accelerator.unwrap_model(self.unet) |
349 | pipeline = StableDiffusionPipeline( | 352 | pipeline = VlpnStableDiffusion( |
350 | text_encoder=self.text_encoder, | 353 | text_encoder=self.text_encoder, |
351 | vae=self.vae, | 354 | vae=self.vae, |
352 | unet=self.accelerator.unwrap_model(self.unet), | 355 | unet=self.accelerator.unwrap_model(self.unet), |
@@ -354,8 +357,6 @@ class Checkpointer: | |||
354 | scheduler=PNDMScheduler( | 357 | scheduler=PNDMScheduler( |
355 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True | 358 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True |
356 | ), | 359 | ), |
357 | safety_checker=NoCheck(), | ||
358 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), | ||
359 | ) | 360 | ) |
360 | pipeline.enable_attention_slicing() | 361 | pipeline.enable_attention_slicing() |
361 | pipeline.save_pretrained(f"{self.output_dir}/model") | 362 | pipeline.save_pretrained(f"{self.output_dir}/model") |
@@ -381,7 +382,6 @@ class Checkpointer: | |||
381 | unet=unwrapped, | 382 | unet=unwrapped, |
382 | tokenizer=self.tokenizer, | 383 | tokenizer=self.tokenizer, |
383 | scheduler=scheduler, | 384 | scheduler=scheduler, |
384 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), | ||
385 | ).to(self.accelerator.device) | 385 | ).to(self.accelerator.device) |
386 | pipeline.enable_attention_slicing() | 386 | pipeline.enable_attention_slicing() |
387 | 387 | ||
@@ -459,44 +459,6 @@ def main(): | |||
459 | if args.seed is not None: | 459 | if args.seed is not None: |
460 | set_seed(args.seed) | 460 | set_seed(args.seed) |
461 | 461 | ||
462 | if args.with_prior_preservation: | ||
463 | class_images_dir = Path(args.class_data_dir) | ||
464 | class_images_dir.mkdir(parents=True, exist_ok=True) | ||
465 | cur_class_images = len(list(class_images_dir.iterdir())) | ||
466 | |||
467 | if cur_class_images < args.num_class_images: | ||
468 | torch_dtype = torch.float32 | ||
469 | if accelerator.device.type == "cuda": | ||
470 | torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision] | ||
471 | |||
472 | pipeline = StableDiffusionPipeline.from_pretrained( | ||
473 | args.pretrained_model_name_or_path, torch_dtype=torch_dtype) | ||
474 | pipeline.enable_attention_slicing() | ||
475 | pipeline.set_progress_bar_config(disable=True) | ||
476 | pipeline.to(accelerator.device) | ||
477 | |||
478 | num_new_images = args.num_class_images - cur_class_images | ||
479 | logger.info(f"Number of class images to sample: {num_new_images}.") | ||
480 | |||
481 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) | ||
482 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) | ||
483 | |||
484 | sample_dataloader = accelerator.prepare(sample_dataloader) | ||
485 | |||
486 | for example in tqdm( | ||
487 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process | ||
488 | ): | ||
489 | with accelerator.autocast(): | ||
490 | images = pipeline(example["prompt"]).images | ||
491 | |||
492 | for i, image in enumerate(images): | ||
493 | image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") | ||
494 | |||
495 | del pipeline | ||
496 | |||
497 | if torch.cuda.is_available(): | ||
498 | torch.cuda.empty_cache() | ||
499 | |||
500 | # Load the tokenizer and add the placeholder token as an additional special token | 462 | # Load the tokenizer and add the placeholder token as an additional special token |
501 | if args.tokenizer_name: | 463 | if args.tokenizer_name: |
502 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) | 464 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) |
@@ -526,6 +488,47 @@ def main(): | |||
526 | freeze_params(vae.parameters()) | 488 | freeze_params(vae.parameters()) |
527 | freeze_params(text_encoder.parameters()) | 489 | freeze_params(text_encoder.parameters()) |
528 | 490 | ||
491 | # Generate class images, if necessary | ||
492 | if args.with_prior_preservation: | ||
493 | class_images_dir = Path(args.class_data_dir) | ||
494 | class_images_dir.mkdir(parents=True, exist_ok=True) | ||
495 | cur_class_images = len(list(class_images_dir.iterdir())) | ||
496 | |||
497 | if cur_class_images < args.num_class_images: | ||
498 | scheduler = EulerAScheduler( | ||
499 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
500 | ) | ||
501 | |||
502 | pipeline = VlpnStableDiffusion( | ||
503 | text_encoder=text_encoder, | ||
504 | vae=vae, | ||
505 | unet=unet, | ||
506 | tokenizer=tokenizer, | ||
507 | scheduler=scheduler, | ||
508 | ).to(accelerator.device) | ||
509 | pipeline.enable_attention_slicing() | ||
510 | pipeline.set_progress_bar_config(disable=True) | ||
511 | |||
512 | num_new_images = args.num_class_images - cur_class_images | ||
513 | logger.info(f"Number of class images to sample: {num_new_images}.") | ||
514 | |||
515 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) | ||
516 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) | ||
517 | |||
518 | sample_dataloader = accelerator.prepare(sample_dataloader) | ||
519 | |||
520 | for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process): | ||
521 | with accelerator.autocast(): | ||
522 | images = pipeline(example["prompt"]).images | ||
523 | |||
524 | for i, image in enumerate(images): | ||
525 | image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") | ||
526 | |||
527 | del pipeline | ||
528 | |||
529 | if torch.cuda.is_available(): | ||
530 | torch.cuda.empty_cache() | ||
531 | |||
529 | if args.scale_lr: | 532 | if args.scale_lr: |
530 | args.learning_rate = ( | 533 | args.learning_rate = ( |
531 | args.learning_rate * args.gradient_accumulation_steps * | 534 | args.learning_rate * args.gradient_accumulation_steps * |