Diffstat (limited to 'textual_inversion.py')

 -rw-r--r--   textual_inversion.py | 112
 1 file changed, 101 insertions(+), 11 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 7919ebd..11c324d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 import os
 
-from data.textual_inversion.csv import CSVDataModule
+from data.csv import CSVDataModule
 
 logger = get_logger(__name__)
 
@@ -68,10 +68,10 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--vectors_per_token",
-        type=int,
-        default=1,
-        help="Vectors per token."
+        "--use_class_images",
+        action="store_true",
+        default=True,
+        help="Include class images in the loss calculation a la Dreambooth.",
     )
     parser.add_argument(
         "--repeats",
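A note on the flag semantics: with action="store_true" and default=True, this option can never be switched off from the command line, since passing --use_class_images sets True and omitting it keeps the True default. If a switchable flag is wanted, one possibility is the sketch below; it is not part of this commit, and argparse.BooleanOptionalAction requires Python 3.9+.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_class_images",
    action=argparse.BooleanOptionalAction,  # also generates --no-use_class_images
    default=True,
    help="Include class images in the loss calculation a la Dreambooth.",
)
args = parser.parse_args(["--no-use_class_images"])
assert args.use_class_images is False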
@@ -234,6 +234,12 @@ def parse_args():
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
+        "--prior_loss_weight",
+        type=float,
+        default=1.0,
+        help="The weight of prior preservation loss."
+    )
+    parser.add_argument(
         "--resume_from",
         type=str,
         default=None,
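For orientation, --prior_loss_weight scales the class-image term of the combined objective added in the final hunk of this diff. A minimal sketch of the weighting (the function name is hypothetical, introduced only for illustration):

def combine_losses(instance_loss, prior_loss, prior_loss_weight=1.0):
    # 0.0 drops prior preservation entirely; the default 1.0 weighs the
    # instance and class terms equally, as in the DreamBooth setup.
    return instance_loss + prior_loss_weight * prior_loss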
@@ -395,7 +401,8 @@ class Checkpointer:
 
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+            prompt = [prompt.format(self.placeholder_token)
+                      for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
             with self.accelerator.autocast():
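The changed line implies that prompts stored in the dataset are str.format templates with a "{}" slot that the checkpointer fills with the placeholder token. A minimal illustration, with hypothetical values:

placeholder_token = "<token>"   # e.g. the value passed via --placeholder_token
template = "a photo of {}"      # hypothetical prompt template from the CSV
assert template.format(placeholder_token) == "a photo of <token>"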
@@ -556,25 +563,94 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    # TODO (patil-suraj): laod scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000
     )
 
+    def collate_fn(examples):
+        prompts = [example["prompts"] for example in examples]
+        nprompts = [example["nprompts"] for example in examples]
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+
+        # concat class and instance examples for prior preservation
+        if args.use_class_images and "class_prompt_ids" in examples[0]:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+
+        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+
+        batch = {
+            "prompts": prompts,
+            "nprompts": nprompts,
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+        }
+        return batch
+
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
+        instance_identifier=args.placeholder_token,
+        class_identifier=args.initializer_token if args.use_class_images else None,
+        class_subdir="ti_cls",
         size=args.resolution,
-        placeholder_token=args.placeholder_token,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.sample_batches
+        valid_set_size=args.sample_batch_size*args.sample_batches,
+        collate_fn=collate_fn
     )
 
     datamodule.prepare_data()
     datamodule.setup()
 
+    if args.use_class_images:
+        missing_data = [item for item in datamodule.data if not item[1].exists()]
+
+        if len(missing_data) != 0:
+            batched_data = [missing_data[i:i+args.sample_batch_size]
+                            for i in range(0, len(missing_data), args.sample_batch_size)]
+
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+
+            for batch in batched_data:
+                image_name = [p[1] for p in batch]
+                prompt = [p[2].format(args.initializer_token) for p in batch]
+                nprompt = [p[3] for p in batch]
+
+                with accelerator.autocast():
+                    images = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        num_inference_steps=args.sample_steps
+                    ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
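Why the training loop in the next hunk can split batches with torch.chunk: when use_class_images is active and class examples are present, collate_fn appends the class examples after the instance examples along the batch dimension, so each collated batch is laid out as [instance..., class...] and is exactly twice the configured batch size. A self-contained shape check under that assumption:

import torch

batch_size = 2
instance_images = [torch.randn(3, 64, 64) for _ in range(batch_size)]
class_images = [torch.randn(3, 64, 64) for _ in range(batch_size)]

# Mirrors collate_fn: instance examples first, class examples second.
pixel_values = torch.stack(instance_images + class_images)
assert pixel_values.shape == (2 * batch_size, 3, 64, 64)

# The model output keeps that ordering, so chunking recovers the two halves.
noise_pred = torch.randn_like(pixel_values)
instance_pred, prior_pred = torch.chunk(noise_pred, 2, dim=0)
assert instance_pred.shape[0] == prior_pred.shape[0] == batch_size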
@@ -693,7 +769,21 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                if args.use_class_images:
+                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                    # Compute instance loss
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
 
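Two details of the loss hunk are worth spelling out. First, F.mse_loss(..., reduction="none").mean([1, 2, 3]).mean() is numerically identical to the default mean-reduced MSE when all samples share a shape; the per-sample form just makes the batch dimension explicit. Second, the prior term enters as a weighted sum, so --prior_loss_weight 0 recovers the plain instance-only loss. A runnable check of both points, using stand-in tensors:

import torch
import torch.nn.functional as F

noise_pred = torch.randn(4, 4, 8, 8)   # stand-in for the UNet output, instance + class halves
noise = torch.randn(4, 4, 8, 8)

# Per-sample mean followed by a batch mean equals the default mean reduction.
per_sample = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
assert torch.allclose(per_sample, F.mse_loss(noise_pred, noise))

# Weighted combination as in the hunk above (1.0 is the CLI default weight).
pred_i, pred_p = torch.chunk(noise_pred, 2, dim=0)
tgt_i, tgt_p = torch.chunk(noise, 2, dim=0)
loss = F.mse_loss(pred_i, tgt_i) + 1.0 * F.mse_loss(pred_p, tgt_p)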
