| author | Volpeon <git@volpeon.ink> | 2022-10-03 21:28:52 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-03 21:28:52 +0200 |
| commit | 46b6c09a18b41edff77c6881529b66733d788abe (patch) | |
| tree | 670e7cdda37ba7a010b570398a63dd38e357b6ce /dreambooth.py | |
| parent | Small perf improvements (diff) | |
Dreambooth: Generate specialized class images from input prompts
Diffstat (limited to 'dreambooth.py')

```
 -rw-r--r--  dreambooth.py | 168
 1 file changed, 70 insertions(+), 98 deletions(-)
```
```diff
diff --git a/dreambooth.py b/dreambooth.py
index 9d6b8d6..2fe89ec 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,13 +13,12 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
-from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
@@ -56,7 +55,13 @@ def parse_args():
         help="A folder containing the training data."
     )
     parser.add_argument(
-        "--identifier",
+        "--instance_identifier",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
         type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
```
```diff
@@ -218,12 +223,6 @@ def parse_args():
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
-        "--instance_prompt",
-        type=str,
-        default=None,
-        help="The prompt with identifier specifing the instance",
-    )
-    parser.add_argument(
         "--class_data_dir",
         type=str,
         default=None,
@@ -231,12 +230,6 @@ def parse_args():
         help="A folder containing the training data of class images.",
     )
     parser.add_argument(
-        "--class_prompt",
-        type=str,
-        default=None,
-        help="The prompt to specify images in the same class as provided intance images.",
-    )
-    parser.add_argument(
         "--prior_loss_weight",
         type=float,
         default=1.0,
@@ -255,15 +248,6 @@ def parse_args():
         help="Max gradient norm."
     )
     parser.add_argument(
-        "--num_class_images",
-        type=int,
-        default=100,
-        help=(
-            "Minimal class images for prior perversation loss. If not have enough images, additional images will be"
-            " sampled with class_prompt."
-        ),
-    )
-    parser.add_argument(
         "--config",
         type=str,
         default=None,
@@ -286,21 +270,12 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.instance_prompt is None:
-        raise ValueError("You must specify --instance_prompt")
-
-    if args.identifier is None:
-        raise ValueError("You must specify --identifier")
+    if args.instance_identifier is None:
+        raise ValueError("You must specify --instance_identifier")
 
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
-    if args.with_prior_preservation:
-        if args.class_data_dir is None:
-            raise ValueError("You must specify --class_data_dir")
-        if args.class_prompt is None:
-            raise ValueError("You must specify --class_prompt")
-
     return args
 
 
@@ -443,7 +418,7 @@ def main():
     args = parse_args()
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
```
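Output directories are now keyed on the slugified instance identifier instead of the old `--identifier`. A small usage sketch of the resulting layout (the identifier and output directory are illustrative values):

```python
import datetime
from pathlib import Path

from slugify import slugify  # python-slugify, as imported in dreambooth.py

now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# Yields e.g. output/sks-dog/2022-10-03T21-28-52
basepath = Path("output").joinpath(slugify("sks-dog"), now)
basepath.mkdir(parents=True, exist_ok=True)
```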
```diff
@@ -488,47 +463,6 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(text_encoder.parameters())
 
-    # Generate class images, if necessary
-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        class_images_dir.mkdir(parents=True, exist_ok=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            scheduler = EulerAScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            )
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=scheduler,
-            ).to(accelerator.device)
-            pipeline.enable_attention_slicing()
-            pipeline.set_progress_bar_config(disable=True)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-
-            for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process):
-                with accelerator.autocast():
-                    images = pipeline(example["prompt"]).images
-
-                for i, image in enumerate(images):
-                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -564,6 +498,7 @@ def main():
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
+        nprompts = [example["nprompts"] for example in examples]
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -579,6 +514,7 @@ def main():
 
         batch = {
             "prompts": prompts,
+            "nprompts": nprompts,
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
```
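`collate_fn` now threads a per-example negative prompt (`nprompts`) through every batch so the sampling code can pass it to the pipeline later. A self-contained sketch of the batch shape (the example dict below is a stand-in for what the dataset yields, and the padding/stacking done by the real function is elided):

```python
import torch

# Stand-in for one dataset item; real items also carry tokenized ids
# produced by the CLIP tokenizer.
examples = [{
    "prompts": "a photo of sks-dog",
    "nprompts": "blurry, low quality",
    "instance_prompt_ids": [101, 2003, 102],
    "instance_images": torch.zeros(3, 512, 512),
}]

def collate_fn(examples):
    return {
        "prompts": [e["prompts"] for e in examples],
        "nprompts": [e["nprompts"] for e in examples],  # the new field
        "input_ids": [e["instance_prompt_ids"] for e in examples],
        "pixel_values": torch.stack([e["instance_images"] for e in examples]),
    }

batch = collate_fn(examples)
assert len(batch["nprompts"]) == len(batch["prompts"])
```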
```diff
@@ -588,11 +524,9 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_prompt=args.instance_prompt,
-        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
-        class_prompt=args.class_prompt,
+        instance_identifier=args.instance_identifier,
+        class_identifier=args.class_identifier,
         size=args.resolution,
-        identifier=args.identifier,
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
@@ -601,6 +535,46 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
+    if args.class_identifier:
+        missing_data = [item for item in datamodule.data if not item[1].exists()]
+
+        if len(missing_data) != 0:
+            batched_data = [missing_data[i:i+args.sample_batch_size]
+                            for i in range(0, len(missing_data), args.sample_batch_size)]
+
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+
+            for batch in batched_data:
+                image_name = [p[1] for p in batch]
+                prompt = [p[2] for p in batch]
+                nprompt = [p[3] for p in batch]
+
+                with accelerator.autocast():
+                    images = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        num_inference_steps=args.sample_steps
+                    ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
```
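The new block is the heart of the commit: instead of sampling a fixed number of images from a single `--class_prompt`, each dataset entry carries its own target path, prompt, and negative prompt, and only images whose files don't exist yet are generated. A distilled sketch of that pattern (the tuple layout `item[1..3]` mirrors the loop above; `generate` stands in for the `VlpnStableDiffusion` pipeline call):

```python
from pathlib import Path

def generate(prompts, nprompts):
    # Stand-in for pipeline(prompt=..., negative_prompt=..., num_inference_steps=...).images
    raise NotImplementedError

def fill_missing_class_images(data, batch_size):
    # Per the loop above: item[1] is the class image path,
    # item[2] the prompt, item[3] the negative prompt.
    missing = [item for item in data if not item[1].exists()]
    batches = [missing[i:i + batch_size] for i in range(0, len(missing), batch_size)]
    for batch in batches:
        paths = [p[1] for p in batch]
        images = generate([p[2] for p in batch], [p[3] for p in batch])
        for path, image in zip(paths, images):
            image.save(path)
```

Deleting the pipeline and calling `torch.cuda.empty_cache()` afterwards, as the diff does, releases the sampling model's VRAM before training starts.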
```diff
@@ -718,23 +692,22 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                with accelerator.autocast():
-                    if args.with_prior_preservation:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                if args.with_prior_preservation:
+                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
-                        # Compute instance loss
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    # Compute instance loss
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                                reduction="none").mean([1, 2, 3]).mean()
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                            reduction="none").mean([1, 2, 3]).mean()
 
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
```
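Dropping the `accelerator.autocast()` wrapper means the loss now runs at the tensors' native precision; the prior-preservation math itself is unchanged. The batch is assumed to hold instance samples in the first half and class samples in the second, so chunking along dim 0 separates the two losses. A standalone sketch with dummy tensors:

```python
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0  # --prior_loss_weight

# Dummy batch of 4: first half instance samples, second half class samples.
noise_pred = torch.randn(4, 4, 64, 64)
noise = torch.randn(4, 4, 64, 64)

noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)

# Per-sample MSE over channel/height/width, then mean over the batch.
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
loss = loss + prior_loss_weight * prior_loss
```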
```diff
@@ -786,8 +759,7 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                with accelerator.autocast():
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 loss = loss.detach().item()
                 val_loss += loss
```