author     Volpeon <git@volpeon.ink>  2022-10-03 21:28:52 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-03 21:28:52 +0200
commit     46b6c09a18b41edff77c6881529b66733d788abe (patch)
tree       670e7cdda37ba7a010b570398a63dd38e357b6ce /dreambooth.py
parent     Small perf improvements (diff)
Dreambooth: Generate specialized class images from input prompts
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  168
1 file changed, 70 insertions(+), 98 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 9d6b8d6..2fe89ec 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,13 +13,12 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
-from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
@@ -56,7 +55,13 @@ def parse_args():
         help="A folder containing the training data."
     )
     parser.add_argument(
-        "--identifier",
+        "--instance_identifier",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
         type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
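
Note: this new pair supersedes --identifier as well as the --instance_prompt/--class_prompt options removed in the hunks below. A minimal, runnable sketch of the pair in isolation (argparse is stdlib; the sample values are hypothetical):

# Standalone sketch of the renamed argument pair from the hunk above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--instance_identifier", type=str, default=None,
                    help="A token to use as a placeholder for the concept.")
parser.add_argument("--class_identifier", type=str, default=None,
                    help="A token to use as a placeholder for the concept.")

args = parser.parse_args(["--instance_identifier", "<concept>",
                          "--class_identifier", "person"])
assert args.instance_identifier == "<concept>"
assert args.class_identifier == "person"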
@@ -218,12 +223,6 @@ def parse_args():
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
-        "--instance_prompt",
-        type=str,
-        default=None,
-        help="The prompt with identifier specifing the instance",
-    )
-    parser.add_argument(
         "--class_data_dir",
         type=str,
         default=None,
@@ -231,12 +230,6 @@ def parse_args():
         help="A folder containing the training data of class images.",
     )
     parser.add_argument(
-        "--class_prompt",
-        type=str,
-        default=None,
-        help="The prompt to specify images in the same class as provided intance images.",
-    )
-    parser.add_argument(
         "--prior_loss_weight",
         type=float,
         default=1.0,
@@ -255,15 +248,6 @@ def parse_args():
         help="Max gradient norm."
     )
     parser.add_argument(
-        "--num_class_images",
-        type=int,
-        default=100,
-        help=(
-            "Minimal class images for prior perversation loss. If not have enough images, additional images will be"
-            " sampled with class_prompt."
-        ),
-    )
-    parser.add_argument(
         "--config",
         type=str,
         default=None,
@@ -286,21 +270,12 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.instance_prompt is None:
-        raise ValueError("You must specify --instance_prompt")
-
-    if args.identifier is None:
-        raise ValueError("You must specify --identifier")
+    if args.instance_identifier is None:
+        raise ValueError("You must specify --instance_identifier")
 
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
-    if args.with_prior_preservation:
-        if args.class_data_dir is None:
-            raise ValueError("You must specify --class_data_dir")
-        if args.class_prompt is None:
-            raise ValueError("You must specify --class_prompt")
-
     return args
 
 
@@ -443,7 +418,7 @@ def main():
     args = parse_args()
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
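
The rename changes only which value is slugified into the output path. A runnable sketch, with "<concept>" standing in for a hypothetical args.instance_identifier and "output" for args.output_dir:

# What the basepath line above computes.
import datetime
from pathlib import Path
from slugify import slugify  # python-slugify, imported at the top of dreambooth.py

now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
basepath = Path("output").joinpath(slugify("<concept>"), now)
print(basepath)  # e.g. output/concept/2022-10-03T21-28-52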
@@ -488,47 +463,6 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(text_encoder.parameters())
 
-    # Generate class images, if necessary
-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        class_images_dir.mkdir(parents=True, exist_ok=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            scheduler = EulerAScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            )
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=scheduler,
-            ).to(accelerator.device)
-            pipeline.enable_attention_slicing()
-            pipeline.set_progress_bar_config(disable=True)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-
-            for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process):
-                with accelerator.autocast():
-                    images = pipeline(example["prompt"]).images
-
-                for i, image in enumerate(images):
-                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -564,6 +498,7 @@ def main():
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
+        nprompts = [example["nprompts"] for example in examples]
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -579,6 +514,7 @@ def main():
 
         batch = {
             "prompts": prompts,
+            "nprompts": nprompts,
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
@@ -588,11 +524,9 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_prompt=args.instance_prompt,
-        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
-        class_prompt=args.class_prompt,
+        instance_identifier=args.instance_identifier,
+        class_identifier=args.class_identifier,
         size=args.resolution,
-        identifier=args.identifier,
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
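
The data module itself is defined elsewhere in the repository, so its constructor is not part of this diff; a stub consistent with the call above would look roughly like this. The class name, defaults, and item layout are assumptions made for illustration only:

# Hypothetical stub of the data module interface implied by the call above.
class DreamBoothDataModule:  # assumed name; not part of this commit
    def __init__(self, data_file, batch_size, tokenizer,
                 instance_identifier, class_identifier=None,
                 size=512, repeats=1, center_crop=False,
                 valid_set_size=None):
        # Items appear to be tuples indexed as (key, image_path, prompt, nprompt)
        # by the class-image generation loop in the next hunk.
        self.data = []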
@@ -601,6 +535,46 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
+    if args.class_identifier:
+        missing_data = [item for item in datamodule.data if not item[1].exists()]
+
+        if len(missing_data) != 0:
+            batched_data = [missing_data[i:i+args.sample_batch_size]
+                            for i in range(0, len(missing_data), args.sample_batch_size)]
+
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+
+            for batch in batched_data:
+                image_name = [p[1] for p in batch]
+                prompt = [p[2] for p in batch]
+                nprompt = [p[3] for p in batch]
+
+                with accelerator.autocast():
+                    images = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        num_inference_steps=args.sample_steps
+                    ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
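
The list-slicing batcher in the new block can be verified in isolation; placeholder integers stand in for the dataset tuples, whose layout (slot 1 a Path to the target image, slots 2 and 3 the prompt and negative prompt) is inferred from the p[1]/p[2]/p[3] indexing above:

# Standalone check of the batching idiom used in the hunk above.
sample_batch_size = 2
missing_data = [0, 1, 2, 3, 4]  # placeholder items
batched_data = [missing_data[i:i + sample_batch_size]
                for i in range(0, len(missing_data), sample_batch_size)]
assert batched_data == [[0, 1], [2, 3], [4]]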
@@ -718,23 +692,22 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                with accelerator.autocast():
-                    if args.with_prior_preservation:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-                        # Compute instance loss
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
-
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                                reduction="none").mean([1, 2, 3]).mean()
-
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                if args.with_prior_preservation:
+                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                    # Compute instance loss
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                            reduction="none").mean([1, 2, 3]).mean()
+
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
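
The prior-preservation arithmetic is unchanged by this commit apart from the dropped autocast wrapper; a self-contained sketch with dummy tensors (shapes and the weight value are illustrative, not taken from the training run):

# Runnable sketch of the chunked prior-preservation loss above.
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0                 # normally set via --prior_loss_weight
noise_pred = torch.randn(4, 4, 64, 64)  # instance and class halves stacked on dim 0
noise = torch.randn(4, 4, 64, 64)

noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)

loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
loss = loss + prior_loss_weight * prior_loss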
@@ -786,8 +759,7 @@ def main():
 
                     noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                    with accelerator.autocast():
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     loss = loss.detach().item()
                     val_loss += loss