 data/csv.py          | 13 ++++++++++---
 dreambooth.py        | 13 +++++++++----
 infer.py             |  4 ++--
 textual_inversion.py | 11 ++++++++---
 4 files changed, 29 insertions(+), 12 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 20ac992..053457b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -54,6 +54,7 @@ class CSVDataModule(pl.LightningDataModule):
         dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
+        mode: Optional[str] = None,
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
@@ -80,6 +81,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.repeats = repeats
         self.dropout = dropout
         self.center_crop = center_crop
+        self.mode = mode
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -99,7 +101,7 @@ class CSVDataModule(pl.LightningDataModule):
                 self.data_root.joinpath(image.format(item["image"])),
                 None,
                 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
-                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
+                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
             )
             for item in data
         ]
@@ -118,7 +120,7 @@ class CSVDataModule(pl.LightningDataModule):
                 item.instance_image_path,
                 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
                 item.prompt,
-                item.nprompt
+                item.nprompt,
             )
             for item in items
             for i in range(image_multiplier)
@@ -130,7 +132,12 @@ class CSVDataModule(pl.LightningDataModule):
         template = metadata[self.template_key] if self.template_key in metadata else {}
         items = metadata["items"] if "items" in metadata else []
 
-        items = [item for item in items if not "skip" in item or item["skip"] != True]
+        if self.mode is not None:
+            items = [
+                item
+                for item in items
+                if "mode" in item and self.mode in item["mode"]
+            ]
         items = self.prepare_items(template, items)
         items = self.filter_items(items)
 
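The hunk above replaces the old per-item `skip` filter with a mode filter: when `CSVDataModule` receives a `mode`, only items whose `mode` list contains that value are kept, and passing no mode keeps everything. A minimal, self-contained sketch of that behaviour (the sample item dicts below are made up; only the filter logic mirrors `data/csv.py`):

```python
from typing import List, Optional


def filter_by_mode(items: List[dict], mode: Optional[str]) -> List[dict]:
    """Mirror of the new filter in CSVDataModule.setup()."""
    if mode is None:
        # No --mode given: keep every item, like the new code path.
        return items
    # Keep only items that declare a "mode" list containing the requested mode.
    return [item for item in items if "mode" in item and mode in item["mode"]]


# Hypothetical metadata items, for illustration only.
items = [
    {"image": "a.jpg", "prompt": "portrait of <x>", "mode": ["subject"]},
    {"image": "b.jpg", "prompt": "a landscape", "mode": ["style", "subject"]},
    {"image": "c.jpg", "prompt": "untagged"},  # no "mode" key: dropped when filtering
]

print([i["image"] for i in filter_by_mode(items, "subject")])  # ['a.jpg', 'b.jpg']
print(len(filter_by_mode(items, None)))                        # 3
```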
diff --git a/dreambooth.py b/dreambooth.py
index 96213d0..3eecf9c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -130,6 +130,12 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
+        "--mode",
+        type=str,
+        default=None,
+        help="A mode to filter the dataset.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -284,7 +290,7 @@ def parse_args():
     parser.add_argument(
         "--sample_frequency",
         type=int,
-        default=100,
+        default=1,
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
@@ -759,6 +765,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
@@ -1046,7 +1053,7 @@
         unet.eval()
         text_encoder.eval()
 
-        with torch.autocast("cuda"), torch.inference_mode():
+        with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
@@ -1063,8 +1070,6 @@
 
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
-
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
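Besides the new `--mode` flag, which is simply forwarded into `CSVDataModule`, the dreambooth.py validation loop now runs under `torch.inference_mode()` alone, without a `torch.autocast("cuda")` wrapper, so the forward pass uses whatever dtype the models were prepared with, and the per-process loss is computed on local tensors only now that the `gather_for_metrics` call is gone. A minimal sketch of that validation pattern (a toy model stands in for the UNet; nothing below is the actual training script):

```python
import torch
import torch.nn.functional as F

# Toy stand-ins for the real models and dataloader.
model = torch.nn.Linear(4, 4)
val_batches = [torch.randn(2, 4) for _ in range(3)]

model.eval()
val_loss = 0.0

# inference_mode() disables autograd tracking; without the autocast("cuda")
# wrapper there is no implicit down-casting during validation.
with torch.inference_mode():
    for batch in val_batches:
        pred = model(batch)
        val_loss += F.mse_loss(pred, torch.zeros_like(pred)).item()

print(val_loss / len(val_batches))
```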
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -34,7 +34,7 @@ torch.backends.cudnn.benchmark = True
 default_args = {
     "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
-    "ti_embeddings_dir": "embeddings_ti",
+    "ti_embeddings_dir": "embeddings",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -190,7 +190,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+    added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir))
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     pipeline = VlpnStableDiffusion(
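The infer.py hunks change the default embeddings directory and wrap the directory string in `pathlib.Path` before handing it to `load_text_embeddings`, presumably because that helper relies on `Path` methods. A small sketch of the pattern with a stand-in helper (the stub below is hypothetical, not the project's real function):

```python
from pathlib import Path


def load_text_embeddings_stub(tokenizer, text_encoder, embeddings_dir: Path):
    # Hypothetical stand-in: iterates the directory with Path.glob, which a
    # plain str argument would not support.
    return [p.stem for p in embeddings_dir.glob("*.bin")]


embeddings_dir = "embeddings"  # matches the new default in default_args
Path(embeddings_dir).mkdir(exist_ok=True)  # only so the sketch runs anywhere

added_tokens = load_text_embeddings_stub(None, None, Path(embeddings_dir))
print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
```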
diff --git a/textual_inversion.py b/textual_inversion.py
index a849d2a..e281c73 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -112,6 +112,12 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
+        "--mode",
+        type=str,
+        default=None,
+        help="A mode to filter the dataset.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -679,6 +685,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
@@ -940,7 +947,7 @@ def main():
         text_encoder.eval()
         val_loss = 0.0
 
-        with torch.autocast("cuda"), torch.inference_mode():
+        with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
@@ -958,8 +965,6 @@
 
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
-
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
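textual_inversion.py gains the same `--mode` flag and validation changes as dreambooth.py, so the remaining practical question is how dataset items opt into a mode. A hypothetical metadata layout is sketched below; the file name, image paths, and prompts are illustrative assumptions, and only the `items`/`mode`/`prompt` keys are taken from `data/csv.py`:

```python
import json

# Hypothetical training metadata: each item lists the modes it belongs to.
metadata = {
    "items": [
        {"image": "img/cat_01.jpg", "prompt": "a photo of <cat>", "mode": ["subject"]},
        {"image": "img/cat_sketch.jpg", "prompt": "a sketch of <cat>", "mode": ["style", "subject"]},
        {"image": "img/background.jpg", "prompt": "an empty room"},  # never matches a --mode filter
    ],
}

with open("metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

# Running e.g. `python textual_inversion.py --mode subject ...` would then keep
# only the first two items; omitting --mode keeps all three.
```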
