 data/csv.py          | 54
 dreambooth.py        | 28
 infer.py             |  5
 textual_inversion.py | 48
 4 files changed, 101 insertions, 34 deletions
diff --git a/data/csv.py b/data/csv.py
index abd329d..dcaf7d3 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,5 +1,3 @@
-import math
-import os
 import pandas as pd
 from pathlib import Path
 import pytorch_lightning as pl
@@ -16,6 +14,7 @@ class CSVDataModule(pl.LightningDataModule):
         instance_identifier,
         class_identifier=None,
         class_subdir="db_cls",
+        num_class_images=2,
         size=512,
         repeats=100,
         interpolation="bicubic",
@@ -33,6 +32,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_root = self.data_file.parent
         self.class_root = self.data_root.joinpath(class_subdir)
         self.class_root.mkdir(parents=True, exist_ok=True)
+        self.num_class_images = num_class_images
 
         self.tokenizer = tokenizer
         self.instance_identifier = instance_identifier
@@ -48,15 +48,37 @@ class CSVDataModule(pl.LightningDataModule):
 
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values]
-        class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values]
-        prompts = metadata['prompt'].values
-        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths)
-        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)
-        self.data = [(i, c, p, n)
-                     for i, c, p, n, s
-                     in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
-                     if s != "x"]
+        instance_image_paths = [
+            self.data_root.joinpath(f)
+            for f in metadata['image'].values
+            for i in range(self.num_class_images)
+        ]
+        class_image_paths = [
+            self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
+            for f in metadata['image'].values
+            for i in range(self.num_class_images)
+        ]
+        prompts = [
+            prompt
+            for prompt in metadata['prompt'].values
+            for i in range(self.num_class_images)
+        ]
+        nprompts = [
+            nprompt
+            for nprompt in metadata['nprompt'].values
+            for i in range(self.num_class_images)
+        ] if 'nprompt' in metadata else [""] * len(instance_image_paths)
+        skips = [
+            skip
+            for skip in metadata['skip'].values
+            for i in range(self.num_class_images)
+        ] if 'skip' in metadata else [""] * len(instance_image_paths)
+        self.data = [
+            (i, c, p, n)
+            for i, c, p, n, s
+            in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
+            if s != "x"
+        ]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data) * 0.2)
@@ -69,6 +91,7 @@ class CSVDataModule(pl.LightningDataModule):
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer,
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
+                                   num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats)
         val_dataset = CSVDataset(self.data_val, self.tokenizer,
@@ -93,6 +116,7 @@ class CSVDataset(Dataset):
         tokenizer,
         instance_identifier,
         class_identifier=None,
+        num_class_images=2,
         size=512,
         repeats=1,
         interpolation="bicubic",
@@ -103,6 +127,7 @@ class CSVDataset(Dataset):
         self.tokenizer = tokenizer
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
+        self.num_class_images = num_class_images
         self.cache = {}
 
         self.num_instance_images = len(self.data)
@@ -128,9 +153,10 @@ class CSVDataset(Dataset):
 
     def get_example(self, i):
         instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
+        cache_key = f"{instance_image_path}_{class_image_path}"
 
-        if instance_image_path in self.cache:
-            return self.cache[instance_image_path]
+        if cache_key in self.cache:
+            return self.cache[cache_key]
 
         example = {}
 
@@ -149,7 +175,7 @@ class CSVDataset(Dataset):
             max_length=self.tokenizer.model_max_length,
         ).input_ids
 
-        if self.class_identifier is not None:
+        if self.num_class_images != 0:
             class_image = Image.open(class_image_path)
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
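Note: with this change, prepare_data expands every row of the CSV into num_class_images entries, so each training image is paired with its own set of generated class images, and the class files are disambiguated by an index in their names. A minimal standalone sketch of that expansion (the file names here are made up for illustration):

    from pathlib import Path

    # Hypothetical inputs standing in for metadata['image'].values.
    images = ["cat1.jpg", "cat2.jpg"]
    num_class_images = 2
    class_root = Path("db_cls")

    # Mirrors the nested comprehensions in CSVDataModule.prepare_data:
    # each image is repeated once per class-image slot.
    instance_image_paths = [
        Path(f)
        for f in images
        for i in range(num_class_images)
    ]
    class_image_paths = [
        class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
        for f in images
        for i in range(num_class_images)
    ]

    print(instance_image_paths)
    # [PosixPath('cat1.jpg'), PosixPath('cat1.jpg'), PosixPath('cat2.jpg'), PosixPath('cat2.jpg')]
    print(class_image_paths)
    # [PosixPath('db_cls/cat1_0_.jpg'), PosixPath('db_cls/cat1_1_.jpg'),
    #  PosixPath('db_cls/cat2_0_.jpg'), PosixPath('db_cls/cat2_1_.jpg')]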
diff --git a/dreambooth.py b/dreambooth.py
index 0e69d79..24e6091 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -66,6 +66,12 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
+    )
+    parser.add_argument(
         "--repeats",
         type=int,
         default=1,
@@ -347,6 +353,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -494,7 +501,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]
 
         # concat class and instance examples for prior preservation
-        if args.class_identifier is not None and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
 
@@ -518,6 +525,7 @@ def main():
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="db_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -528,7 +536,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.class_identifier is not None:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]
 
         if len(missing_data) != 0:
@@ -547,6 +555,7 @@ def main():
                 scheduler=scheduler,
             ).to(accelerator.device)
             pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             for batch in batched_data:
                 image_name = [p[1] for p in batch]
@@ -645,11 +654,18 @@ def main():
            0,
            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")
 
     try:
@@ -686,7 +702,7 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                if args.class_identifier is not None:
+                if args.num_class_images != 0:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
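For context, the chunked tensors above feed a DreamBooth-style prior-preservation objective: the collate_fn concatenates instance and class examples, so splitting the batch in half recovers the two groups. A rough sketch of that loss, assuming an F.mse_loss reconstruction term and a prior_loss_weight factor (neither appears in this hunk):

    import torch
    import torch.nn.functional as F

    def prior_preservation_loss(noise_pred, noise, prior_loss_weight=1.0):
        # The batch is [instance examples, class examples], concatenated by
        # the collate_fn, so chunking in half separates the two groups again.
        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
        noise, noise_prior = torch.chunk(noise, 2, dim=0)

        # Reconstruction loss on the instance images.
        loss = F.mse_loss(noise_pred, noise, reduction="mean")
        # Prior-preservation loss on the generated class images.
        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="mean")
        return loss + prior_loss_weight * prior_loss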
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -262,7 +262,10 @@ def generate(output_dir, pipeline, args):
 
     with autocast("cuda"):
         for i in range(args.batch_num):
-            pipeline.set_progress_bar_config(desc=f"Batch {i + 1} of {args.batch_num}")
+            pipeline.set_progress_bar_config(
+                desc=f"Batch {i + 1} of {args.batch_num}",
+                dynamic_ncols=True
+            )
 
             generator = torch.Generator(device="cuda").manual_seed(seed + i)
             images = pipeline(
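dynamic_ncols=True is a plain tqdm option: the bar re-measures the terminal width on every refresh instead of fixing it at creation time, which keeps nested bars readable when the window is resized. set_progress_bar_config forwards its keyword arguments to the tqdm bars the pipeline creates internally. A tiny standalone illustration of the flag itself:

    from tqdm.auto import tqdm

    # With dynamic_ncols=True the bar adapts if the terminal is resized mid-run.
    for _ in tqdm(range(1_000_000), desc="Total progress", dynamic_ncols=True):
        pass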
diff --git a/textual_inversion.py b/textual_inversion.py
index 11c324d..86fcdfe 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -68,16 +68,17 @@ def parse_args():
         help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--use_class_images",
-        action="store_true",
-        default=True,
-        help="Include class images in the loss calculation a la Dreambooth.",
+        "--num_class_images",
+        type=int,
+        default=2,
+        help="How many class images to generate per training image."
     )
     parser.add_argument(
         "--repeats",
         type=int,
         default=100,
-        help="How many times to repeat the training data.")
+        help="How many times to repeat the training data."
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -204,6 +205,12 @@ def parse_args():
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=100,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=512,
@@ -381,6 +388,7 @@ class Checkpointer:
             scheduler=scheduler,
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
@@ -577,7 +585,7 @@ def main():
         pixel_values = [example["instance_images"] for example in examples]
 
         # concat class and instance examples for prior preservation
-        if args.use_class_images and "class_prompt_ids" in examples[0]:
+        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
 
@@ -599,8 +607,9 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token if args.use_class_images else None,
+        class_identifier=args.initializer_token,
         class_subdir="ti_cls",
+        num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
@@ -611,7 +620,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.use_class_images:
+    if args.num_class_images != 0:
         missing_data = [item for item in datamodule.data if not item[1].exists()]
 
         if len(missing_data) != 0:
@@ -630,6 +639,7 @@ def main():
                 scheduler=scheduler,
             ).to(accelerator.device)
             pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             for batch in batched_data:
                 image_name = [p[1] for p in batch]
@@ -729,11 +739,18 @@ def main():
            text_encoder,
            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
-                              disable=not accelerator.is_local_main_process)
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     local_progress_bar.set_description("Batch X out of Y")
 
-    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar = tqdm(
+        range(args.max_train_steps + val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
     global_progress_bar.set_description("Total progress")
 
     try:
@@ -744,6 +761,8 @@ def main():
         text_encoder.train()
         train_loss = 0.0
 
+        sample_checkpoint = False
+
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
@@ -769,7 +788,7 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                if args.use_class_images:
+                if args.num_class_images != 0:
                     # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                     noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                     noise, noise_prior = torch.chunk(noise, 2, dim=0)
@@ -812,6 +831,9 @@ def main():
 
                 global_step += 1
 
+                if global_step % args.sample_frequency == 0:
+                    sample_checkpoint = True
+
                 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
@@ -878,7 +900,7 @@ def main():
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                 min_val_loss = val_loss
 
-        if accelerator.is_main_process:
+        if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
                 global_step + global_step_offset,
                 text_encoder,
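The new sample_checkpoint flag decouples sample generation from checkpointing: it is raised whenever global_step crosses a multiple of --sample_frequency and consumed once at the end of the epoch, on the main process only. A condensed, self-contained sketch of that control flow with the training work replaced by prints (the step counts here are made up):

    num_epochs = 3
    steps_per_epoch = 250
    sample_frequency = 100      # --sample_frequency
    checkpoint_frequency = 400  # --checkpoint_frequency

    global_step = 0
    for epoch in range(num_epochs):
        sample_checkpoint = False

        for step in range(steps_per_epoch):
            global_step += 1

            if global_step % sample_frequency == 0:
                sample_checkpoint = True  # remember that a sample is due

            if global_step % checkpoint_frequency == 0:
                print(f"step {global_step}: save checkpoint")

        # Samples are written at most once per epoch, and only if the step
        # counter crossed a sample_frequency boundary during that epoch.
        if sample_checkpoint:
            print(f"epoch {epoch}: save sample images")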