author     Volpeon <git@volpeon.ink>  2022-10-06 17:15:22 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-06 17:15:22 +0200
commit     49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693
tree       8bd8fe59b2a5b60c2f6e7e1b48b53be7fbf1e155 /textual_inversion.py
parent     Inference: Add support for embeddings
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  112
1 file changed, 101 insertions(+), 11 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 7919ebd..11c324d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 import os
 
-from data.textual_inversion.csv import CSVDataModule
+from data.csv import CSVDataModule
 
 logger = get_logger(__name__)
 
@@ -68,10 +68,10 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--vectors_per_token",
-        type=int,
-        default=1,
-        help="Vectors per token."
+        "--use_class_images",
+        action="store_true",
+        default=True,
+        help="Include class images in the loss calculation a la Dreambooth.",
     )
     parser.add_argument(
         "--repeats",
@@ -234,6 +234,12 @@ def parse_args():
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
+        "--prior_loss_weight",
+        type=float,
+        default=1.0,
+        help="The weight of prior preservation loss."
+    )
+    parser.add_argument(
         "--resume_from",
         type=str,
         default=None,
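For reference, a hypothetical invocation exercising the options touched by this commit (file paths and token values are placeholders, not from the source; all argument names appear in the diff):

    python textual_inversion.py \
        --train_data_file data/train.csv \
        --placeholder_token "<token>" \
        --initializer_token "person" \
        --prior_loss_weight 1.0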
@@ -395,7 +401,8 @@ class Checkpointer:
 
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+            prompt = [prompt.format(self.placeholder_token)
+                      for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
             with self.accelerator.autocast():
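The Checkpointer now treats each sample prompt as a template with a `{}` slot for the placeholder token. A minimal sketch of the formatting step in isolation (the template strings are illustrative):

    prompts = ["a photo of {}", "a painting of {} at night"]
    placeholder_token = "<my-token>"

    formatted = [p.format(placeholder_token) for p in prompts]
    # ["a photo of <my-token>", "a painting of <my-token> at night"]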
@@ -556,25 +563,94 @@ def main():
         eps=args.adam_epsilon,
     )
 
-    # TODO (patil-suraj): laod scheduler using args
     noise_scheduler = DDPMScheduler(
-        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000
     )
 
+    def collate_fn(examples):
+        prompts = [example["prompts"] for example in examples]
+        nprompts = [example["nprompts"] for example in examples]
+        input_ids = [example["instance_prompt_ids"] for example in examples]
+        pixel_values = [example["instance_images"] for example in examples]
+
+        # concat class and instance examples for prior preservation
+        if args.use_class_images and "class_prompt_ids" in examples[0]:
+            input_ids += [example["class_prompt_ids"] for example in examples]
+            pixel_values += [example["class_images"] for example in examples]
+
+        pixel_values = torch.stack(pixel_values)
+        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+
+        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+
+        batch = {
+            "prompts": prompts,
+            "nprompts": nprompts,
+            "input_ids": input_ids,
+            "pixel_values": pixel_values,
+        }
+        return batch
+
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
+        instance_identifier=args.placeholder_token,
+        class_identifier=args.initializer_token if args.use_class_images else None,
+        class_subdir="ti_cls",
         size=args.resolution,
-        placeholder_token=args.placeholder_token,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.sample_batches
+        valid_set_size=args.sample_batch_size*args.sample_batches,
+        collate_fn=collate_fn
     )
 
     datamodule.prepare_data()
     datamodule.setup()
 
+    if args.use_class_images:
+        missing_data = [item for item in datamodule.data if not item[1].exists()]
+
+        if len(missing_data) != 0:
+            batched_data = [missing_data[i:i+args.sample_batch_size]
+                            for i in range(0, len(missing_data), args.sample_batch_size)]
+
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+
+            for batch in batched_data:
+                image_name = [p[1] for p in batch]
+                prompt = [p[2].format(args.initializer_token) for p in batch]
+                nprompt = [p[3] for p in batch]
+
+                with accelerator.autocast():
+                    images = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        num_inference_steps=args.sample_steps
+                    ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
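Because `collate_fn` appends the class examples after the instance examples, enabling class images turns a batch of size B into tensors with leading dimension 2B, instance rows first. The loss change in the next hunk depends on that ordering. A standalone sketch with toy tensors (shapes are illustrative):

    import torch

    B = 2
    instance_images = [torch.randn(3, 64, 64) for _ in range(B)]
    class_images = [torch.randn(3, 64, 64) for _ in range(B)]

    # Instance rows first, class rows second -- same ordering as collate_fn.
    pixel_values = torch.stack(instance_images + class_images)  # [2B, 3, 64, 64]

    # torch.chunk along dim 0 recovers the two halves in order.
    instance_half, class_half = torch.chunk(pixel_values, 2, dim=0)  # each [B, 3, 64, 64]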
@@ -693,7 +769,21 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                if args.use_class_images:
+                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                    # Compute instance loss
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
 
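The split-loss branch assumes the 2B batch layout described above: the first half of `noise_pred`/`noise` comes from instance images, the second half from class images. A toy sketch of the arithmetic, with random tensors standing in for real model outputs (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    prior_loss_weight = 1.0               # mirrors --prior_loss_weight
    noise_pred = torch.randn(4, 4, 8, 8)  # 2B = 4: instance half + class half
    noise = torch.randn(4, 4, 8, 8)

    pred_inst, pred_prior = torch.chunk(noise_pred, 2, dim=0)
    target_inst, target_prior = torch.chunk(noise, 2, dim=0)

    # Per-image MSE, then mean over the batch -- the same reduction as the training code.
    loss = F.mse_loss(pred_inst, target_inst, reduction="none").mean([1, 2, 3]).mean()
    prior_loss = F.mse_loss(pred_prior, target_prior, reduction="none").mean([1, 2, 3]).mean()
    loss = loss + prior_loss_weight * prior_loss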