summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-06 17:15:22 +0200
committerVolpeon <git@volpeon.ink>2022-10-06 17:15:22 +0200
commit49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 (patch)
tree8bd8fe59b2a5b60c2f6e7e1b48b53be7fbf1e155 /textual_inversion.py
parentInference: Add support for embeddings (diff)
downloadtextual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.gz
textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.tar.bz2
textual-inversion-diff-49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693.zip
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py112
1 files changed, 101 insertions, 11 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 7919ebd..11c324d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25import json 25import json
26import os 26import os
27 27
28from data.textual_inversion.csv import CSVDataModule 28from data.csv import CSVDataModule
29 29
30logger = get_logger(__name__) 30logger = get_logger(__name__)
31 31
@@ -68,10 +68,10 @@ def parse_args():
68 help="A token to use as initializer word." 68 help="A token to use as initializer word."
69 ) 69 )
70 parser.add_argument( 70 parser.add_argument(
71 "--vectors_per_token", 71 "--use_class_images",
72 type=int, 72 action="store_true",
73 default=1, 73 default=True,
74 help="Vectors per token." 74 help="Include class images in the loss calculation a la Dreambooth.",
75 ) 75 )
76 parser.add_argument( 76 parser.add_argument(
77 "--repeats", 77 "--repeats",
@@ -234,6 +234,12 @@ def parse_args():
234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
235 ) 235 )
236 parser.add_argument( 236 parser.add_argument(
237 "--prior_loss_weight",
238 type=float,
239 default=1.0,
240 help="The weight of prior preservation loss."
241 )
242 parser.add_argument(
237 "--resume_from", 243 "--resume_from",
238 type=str, 244 type=str,
239 default=None, 245 default=None,
@@ -395,7 +401,8 @@ class Checkpointer:
395 401
396 for i in range(self.sample_batches): 402 for i in range(self.sample_batches):
397 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 403 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
398 prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] 404 prompt = [prompt.format(self.placeholder_token)
405 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
399 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] 406 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
400 407
401 with self.accelerator.autocast(): 408 with self.accelerator.autocast():
@@ -556,25 +563,94 @@ def main():
556 eps=args.adam_epsilon, 563 eps=args.adam_epsilon,
557 ) 564 )
558 565
559 # TODO (patil-suraj): laod scheduler using args
560 noise_scheduler = DDPMScheduler( 566 noise_scheduler = DDPMScheduler(
561 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 567 beta_start=0.00085,
568 beta_end=0.012,
569 beta_schedule="scaled_linear",
570 num_train_timesteps=1000
562 ) 571 )
563 572
573 def collate_fn(examples):
574 prompts = [example["prompts"] for example in examples]
575 nprompts = [example["nprompts"] for example in examples]
576 input_ids = [example["instance_prompt_ids"] for example in examples]
577 pixel_values = [example["instance_images"] for example in examples]
578
579 # concat class and instance examples for prior preservation
580 if args.use_class_images and "class_prompt_ids" in examples[0]:
581 input_ids += [example["class_prompt_ids"] for example in examples]
582 pixel_values += [example["class_images"] for example in examples]
583
584 pixel_values = torch.stack(pixel_values)
585 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
586
587 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
588
589 batch = {
590 "prompts": prompts,
591 "nprompts": nprompts,
592 "input_ids": input_ids,
593 "pixel_values": pixel_values,
594 }
595 return batch
596
564 datamodule = CSVDataModule( 597 datamodule = CSVDataModule(
565 data_file=args.train_data_file, 598 data_file=args.train_data_file,
566 batch_size=args.train_batch_size, 599 batch_size=args.train_batch_size,
567 tokenizer=tokenizer, 600 tokenizer=tokenizer,
601 instance_identifier=args.placeholder_token,
602 class_identifier=args.initializer_token if args.use_class_images else None,
603 class_subdir="ti_cls",
568 size=args.resolution, 604 size=args.resolution,
569 placeholder_token=args.placeholder_token,
570 repeats=args.repeats, 605 repeats=args.repeats,
571 center_crop=args.center_crop, 606 center_crop=args.center_crop,
572 valid_set_size=args.sample_batch_size*args.sample_batches 607 valid_set_size=args.sample_batch_size*args.sample_batches,
608 collate_fn=collate_fn
573 ) 609 )
574 610
575 datamodule.prepare_data() 611 datamodule.prepare_data()
576 datamodule.setup() 612 datamodule.setup()
577 613
614 if args.use_class_images:
615 missing_data = [item for item in datamodule.data if not item[1].exists()]
616
617 if len(missing_data) != 0:
618 batched_data = [missing_data[i:i+args.sample_batch_size]
619 for i in range(0, len(missing_data), args.sample_batch_size)]
620
621 scheduler = EulerAScheduler(
622 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
623 )
624
625 pipeline = VlpnStableDiffusion(
626 text_encoder=text_encoder,
627 vae=vae,
628 unet=unet,
629 tokenizer=tokenizer,
630 scheduler=scheduler,
631 ).to(accelerator.device)
632 pipeline.enable_attention_slicing()
633
634 for batch in batched_data:
635 image_name = [p[1] for p in batch]
636 prompt = [p[2].format(args.initializer_token) for p in batch]
637 nprompt = [p[3] for p in batch]
638
639 with accelerator.autocast():
640 images = pipeline(
641 prompt=prompt,
642 negative_prompt=nprompt,
643 num_inference_steps=args.sample_steps
644 ).images
645
646 for i, image in enumerate(images):
647 image.save(image_name[i])
648
649 del pipeline
650
651 if torch.cuda.is_available():
652 torch.cuda.empty_cache()
653
578 train_dataloader = datamodule.train_dataloader() 654 train_dataloader = datamodule.train_dataloader()
579 val_dataloader = datamodule.val_dataloader() 655 val_dataloader = datamodule.val_dataloader()
580 656
@@ -693,7 +769,21 @@ def main():
693 # Predict the noise residual 769 # Predict the noise residual
694 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 770 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
695 771
696 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 772 if args.use_class_images:
773 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
774 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
775 noise, noise_prior = torch.chunk(noise, 2, dim=0)
776
777 # Compute instance loss
778 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
779
780 # Compute prior loss
781 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
782
783 # Add the prior loss to the instance loss.
784 loss = loss + args.prior_loss_weight * prior_loss
785 else:
786 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
697 787
698 accelerator.backward(loss) 788 accelerator.backward(loss)
699 789