-rw-r--r--  data/csv.py           | 54
-rw-r--r--  dreambooth.py         | 28
-rw-r--r--  infer.py              |  5
-rw-r--r--  textual_inversion.py  | 48
4 files changed, 101 insertions, 34 deletions
diff --git a/data/csv.py b/data/csv.py
index abd329d..dcaf7d3 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,5 +1,3 @@
-import math
-import os
 import pandas as pd
 from pathlib import Path
 import pytorch_lightning as pl
@@ -16,6 +14,7 @@ class CSVDataModule(pl.LightningDataModule):
         instance_identifier,
         class_identifier=None,
         class_subdir="db_cls",
+        num_class_images=2,
         size=512,
         repeats=100,
         interpolation="bicubic",
@@ -33,6 +32,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_root = self.data_file.parent
         self.class_root = self.data_root.joinpath(class_subdir)
         self.class_root.mkdir(parents=True, exist_ok=True)
+        self.num_class_images = num_class_images

         self.tokenizer = tokenizer
         self.instance_identifier = instance_identifier
@@ -48,15 +48,37 @@ class CSVDataModule(pl.LightningDataModule):

     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values]
-        class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values]
-        prompts = metadata['prompt'].values
-        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths)
-        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)
-        self.data = [(i, c, p, n)
-                     for i, c, p, n, s
-                     in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
-                     if s != "x"]
+        instance_image_paths = [
+            self.data_root.joinpath(f)
+            for f in metadata['image'].values
+            for i in range(self.num_class_images)
+        ]
+        class_image_paths = [
+            self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
+            for f in metadata['image'].values
+            for i in range(self.num_class_images)
+        ]
+        prompts = [
+            prompt
+            for prompt in metadata['prompt'].values
+            for i in range(self.num_class_images)
+        ]
+        nprompts = [
+            nprompt
+            for nprompt in metadata['nprompt'].values
+            for i in range(self.num_class_images)
+        ] if 'nprompt' in metadata else [""] * len(instance_image_paths)
+        skips = [
+            skip
+            for skip in metadata['skip'].values
+            for i in range(self.num_class_images)
+        ] if 'skip' in metadata else [""] * len(instance_image_paths)
+        self.data = [
+            (i, c, p, n)
+            for i, c, p, n, s
+            in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
+            if s != "x"
+        ]

     def setup(self, stage=None):
         valid_set_size = int(len(self.data) * 0.2)
@@ -69,6 +91,7 @@ class CSVDataModule(pl.LightningDataModule):

         train_dataset = CSVDataset(self.data_train, self.tokenizer,
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
+                                   num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats)
         val_dataset = CSVDataset(self.data_val, self.tokenizer,
@@ -93,6 +116,7 @@ class CSVDataset(Dataset):
         tokenizer,
         instance_identifier,
         class_identifier=None,
+        num_class_images=2,
         size=512,
         repeats=1,
         interpolation="bicubic",
@@ -103,6 +127,7 @@ class CSVDataset(Dataset):
         self.tokenizer = tokenizer
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
+        self.num_class_images = num_class_images
         self.cache = {}

         self.num_instance_images = len(self.data)
@@ -128,9 +153,10 @@ class CSVDataset(Dataset):

     def get_example(self, i):
         instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
+        cache_key = f"{instance_image_path}_{class_image_path}"

-        if instance_image_path in self.cache:
-            return self.cache[instance_image_path]
+        if cache_key in self.cache:
+            return self.cache[cache_key]

         example = {}

@@ -149,7 +175,7 @@ class CSVDataset(Dataset):
             max_length=self.tokenizer.model_max_length,
         ).input_ids

-        if self.class_identifier is not None:
+        if self.num_class_images != 0:
             class_image = Image.open(class_image_path)
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
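For orientation (this sketch is not part of the commit; the helper name expand_rows and its arguments are made up for illustration), the reworked prepare_data pairs every training image with num_class_images generated class-image paths by repeating each CSV row, roughly like this:

    from pathlib import Path

    def expand_rows(image_names, num_class_images, class_root):
        # image_names: the CSV 'image' column; class_root: the db_cls/ directory
        instance_paths = [Path(name) for name in image_names for _ in range(num_class_images)]
        class_paths = [
            class_root / f"{Path(name).stem}_{i}_{Path(name).suffix}"
            for name in image_names
            for i in range(num_class_images)
        ]
        return list(zip(instance_paths, class_paths))

    # expand_rows(["cat.png"], 2, Path("db_cls")) pairs cat.png with
    # db_cls/cat_0_.png and db_cls/cat_1_.png, matching the naming used above.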
diff --git a/dreambooth.py b/dreambooth.py
index 0e69d79..24e6091 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -66,6 +66,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
+    )
+    parser.add_argument(
         "--repeats",
         type=int,
         default=1,
@@ -347,6 +353,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)

         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -494,7 +501,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]

         # concat class and instance examples for prior preservation
-        if args.class_identifier is not None and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]

@@ -518,6 +525,7 @@ def main():
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="db_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -528,7 +536,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()

-    if args.class_identifier is not None:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]

         if len(missing_data) != 0:
@@ -547,6 +555,7 @@ def main():
             scheduler=scheduler,
         ).to(accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)

         for batch in batched_data:
             image_name = [p[1] for p in batch]
@@ -645,11 +654,18 @@ def main():
         0,
         args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")

-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")

     try:
@@ -686,7 +702,7 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-                if args.class_identifier is not None:
+                if args.num_class_images != 0:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
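The guard change from class_identifier to num_class_images feeds the same prior-preservation loss as before. As a rough, self-contained sketch (assumed 4D tensor shapes and an assumed prior weight; the weighting details are not shown in this diff), the chunked loss looks like:

    import torch
    import torch.nn.functional as F

    def prior_preservation_loss(noise_pred, noise, prior_weight=1.0):
        # The collate_fn appends class examples after instance examples, so the
        # first half of the batch is instance noise, the second half class noise.
        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
        noise, noise_prior = torch.chunk(noise, 2, dim=0)

        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
        return loss + prior_weight * prior_loss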
diff --git a/infer.py b/infer.py
index 34e570a..6197aa3 100644
--- a/infer.py
+++ b/infer.py
@@ -262,7 +262,10 @@ def generate(output_dir, pipeline, args):

     with autocast("cuda"):
         for i in range(args.batch_num):
-            pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
+            pipeline.set_progress_bar_config(
+                desc=f"Batch {i + 1} of {args.batch_num}",
+                dynamic_ncols=True
+            )

             generator = torch.Generator(device="cuda").manual_seed(seed + i)
             images = pipeline(
diff --git a/textual_inversion.py b/textual_inversion.py
index 11c324d..86fcdfe 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -68,16 +68,17 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--use_class_images",
-        action="store_true",
-        default=True,
-        help="Include class images in the loss calculation a la Dreambooth.",
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
     )
     parser.add_argument(
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -204,6 +205,12 @@ def parse_args():
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=100,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=512,
@@ -381,6 +388,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)

         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -577,7 +585,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]

         # concat class and instance examples for prior preservation
-        if args.use_class_images and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]

@@ -599,8 +607,9 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token if args.use_class_images else None,
+        class_identifier=args.initializer_token,
         class_subdir="ti_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -611,7 +620,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()

-    if args.use_class_images:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]

         if len(missing_data) != 0:
@@ -630,6 +639,7 @@ def main():
             scheduler=scheduler,
         ).to(accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)

         for batch in batched_data:
             image_name = [p[1] for p in batch]
@@ -729,11 +739,18 @@ def main():
         text_encoder,
         args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")

-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")

     try:
@@ -744,6 +761,8 @@ def main():
         text_encoder.train()
         train_loss = 0.0

+        sample_checkpoint = False
+
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
@@ -769,7 +788,7 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

-                if args.use_class_images:
+                if args.num_class_images != 0:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
@@ -812,6 +831,9 @@ def main():

                 global_step += 1

+                if global_step % args.sample_frequency == 0:
+                    sample_checkpoint = True
+
                 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
@@ -878,7 +900,7 @@ def main():
             checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
             min_val_loss = val_loss

-        if accelerator.is_main_process:
+        if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
                 global_step + global_step_offset,
                 text_encoder,
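The new --sample_frequency flag decouples sample generation from checkpointing: sample_checkpoint is reset at the start of each epoch and flipped whenever the global step crosses a multiple of sample_frequency, so samples are rendered at most once per epoch on the main process. A self-contained sketch of that cadence (the helper steps_that_sample is invented for illustration):

    def steps_that_sample(total_steps, steps_per_epoch, sample_frequency):
        sampled_epochs = []
        global_step = 0
        for epoch in range(total_steps // steps_per_epoch):
            sample_checkpoint = False
            for _ in range(steps_per_epoch):
                global_step += 1
                if global_step % sample_frequency == 0:
                    sample_checkpoint = True
            if sample_checkpoint:  # the script additionally requires accelerator.is_main_process
                sampled_epochs.append(epoch)
        return sampled_epochs

    # steps_that_sample(40, 10, 25) == [2]: only the epoch containing step 25 samples.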