-rw-r--r--   data/csv.py            13
-rw-r--r--   dreambooth.py          13
-rw-r--r--   infer.py                4
-rw-r--r--   textual_inversion.py   11
4 files changed, 29 insertions, 12 deletions
diff --git a/data/csv.py b/data/csv.py
index 20ac992..053457b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -54,6 +54,7 @@ class CSVDataModule(pl.LightningDataModule):
         dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
+        mode: Optional[str] = None,
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         generator: Optional[torch.Generator] = None,
@@ -80,6 +81,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.repeats = repeats
         self.dropout = dropout
         self.center_crop = center_crop
+        self.mode = mode
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -99,7 +101,7 @@ class CSVDataModule(pl.LightningDataModule):
                 self.data_root.joinpath(image.format(item["image"])),
                 None,
                 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
-                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
+                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
             )
             for item in data
         ]
@@ -118,7 +120,7 @@ class CSVDataModule(pl.LightningDataModule):
                 item.instance_image_path,
                 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
                 item.prompt,
-                item.nprompt
+                item.nprompt,
             )
             for item in items
             for i in range(image_multiplier)
@@ -130,7 +132,12 @@ class CSVDataModule(pl.LightningDataModule):
         template = metadata[self.template_key] if self.template_key in metadata else {}
         items = metadata["items"] if "items" in metadata else []
 
-        items = [item for item in items if not "skip" in item or item["skip"] != True]
+        if self.mode is not None:
+            items = [
+                item
+                for item in items
+                if "mode" in item and self.mode in item["mode"]
+            ]
         items = self.prepare_items(template, items)
         items = self.filter_items(items)
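
The change to data/csv.py drops the old per-item "skip" flag and instead filters metadata items by a new "mode" value: when a mode is given, only items whose "mode" field contains it are kept. A standalone sketch of that filter, using made-up item dicts for illustration:

# Hypothetical metadata items; real entries come from the dataset's metadata file.
items = [
    {"image": "a.jpg", "mode": ["style", "subject"]},
    {"image": "b.jpg", "mode": ["style"]},
    {"image": "c.jpg"},  # no "mode" key: dropped whenever a mode is requested
]

mode = "subject"
if mode is not None:
    items = [
        item
        for item in items
        if "mode" in item and mode in item["mode"]
    ]

print([item["image"] for item in items])  # ['a.jpg']

With a mode set, items without a "mode" entry are excluded; with no mode set, no filtering happens at all (including the removed "skip" handling).
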
diff --git a/dreambooth.py b/dreambooth.py
index 96213d0..3eecf9c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -130,6 +130,12 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
+        "--mode",
+        type=str,
+        default=None,
+        help="A mode to filter the dataset.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -284,7 +290,7 @@ def parse_args():
     parser.add_argument(
         "--sample_frequency",
         type=int,
-        default=100,
+        default=1,
         help="How often to save a checkpoint and sample image",
     )
     parser.add_argument(
@@ -759,6 +765,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
@@ -1046,7 +1053,7 @@ def main():
             unet.eval()
             text_encoder.eval()
 
-            with torch.autocast("cuda"), torch.inference_mode():
+            with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
@@ -1063,8 +1070,6 @@
 
                     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                    model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
-
                     # Get the target for loss depending on the prediction type
                     if noise_scheduler.config.prediction_type == "epsilon":
                         target = noise
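
dreambooth.py gains a matching --mode CLI flag, forwards it to CSVDataModule, lowers the default sample_frequency from 100 to 1, and runs validation without torch.autocast and without gathering predictions across processes. A minimal sketch of how the new flag parses (the argv shown here is only an example):

import argparse

# Only the new --mode argument is reproduced; the real parser defines many more flags.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--mode",
    type=str,
    default=None,
    help="A mode to filter the dataset.",
)

args = parser.parse_args(["--mode", "style"])
print(args.mode)  # "style"; later forwarded as CSVDataModule(mode=args.mode, ...)
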
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -34,7 +34,7 @@ torch.backends.cudnn.benchmark = True
 default_args = {
     "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
-    "ti_embeddings_dir": "embeddings_ti",
+    "ti_embeddings_dir": "embeddings",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -190,7 +190,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
+    added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir))
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
 
     pipeline = VlpnStableDiffusion(
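
In infer.py, the default embeddings directory moves from "embeddings_ti" to "embeddings", and the directory is wrapped in Path() before being handed to load_text_embeddings, so the loader receives a pathlib.Path even when a plain string comes from default_args or the CLI. A rough sketch of that pattern (this helper and the *.bin glob are assumptions, not the project's actual loader):

from pathlib import Path

def list_embedding_names(embeddings_dir) -> list:
    # Hypothetical helper: accepts str or Path because everything is normalised
    # through Path(); the real load_text_embeddings may work differently.
    return [file.stem for file in sorted(Path(embeddings_dir).glob("*.bin"))]

print(list_embedding_names("embeddings"))  # [] when the directory is missing or empty
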
diff --git a/textual_inversion.py b/textual_inversion.py
index a849d2a..e281c73 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -112,6 +112,12 @@ def parse_args():
         help="The embeddings directory where Textual Inversion embeddings are stored.",
     )
     parser.add_argument(
+        "--mode",
+        type=str,
+        default=None,
+        help="A mode to filter the dataset.",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
         default=None,
@@ -679,6 +685,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        mode=args.mode,
         dropout=args.tag_dropout,
         center_crop=args.center_crop,
         template_key=args.train_data_template,
@@ -940,7 +947,7 @@ def main():
             text_encoder.eval()
             val_loss = 0.0
 
-            with torch.autocast("cuda"), torch.inference_mode():
+            with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
@@ -958,8 +965,6 @@
 
                     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                    model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
-
                     # Get the target for loss depending on the prediction type
                     if noise_scheduler.config.prediction_type == "epsilon":
                         target = noise
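
textual_inversion.py mirrors the dreambooth.py changes: the new --mode flag is forwarded to CSVDataModule, and validation now runs under torch.inference_mode() alone, without torch.autocast("cuda") and without gathering model_pred/noise across processes before the loss is computed. A minimal sketch of that per-process validation pattern (the model, dataloader, and MSE loss here are stand-ins for the scripts' actual diffusion objects):

import torch
import torch.nn.functional as F

def validate(model, val_dataloader):
    # Per-process validation: no autocast context, no cross-process gather of predictions.
    model.eval()
    val_loss = 0.0
    with torch.inference_mode():
        for inputs, target in val_dataloader:
            pred = model(inputs)
            val_loss += F.mse_loss(pred.float(), target.float()).item()
    return val_loss / max(len(val_dataloader), 1)
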