summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-15 20:30:59 +0100
committerVolpeon <git@volpeon.ink>2022-12-15 20:30:59 +0100
commit8f4d212b3833041448678ad8a44a9a327934f74a (patch)
tree667edaef8a771a171db4a5afdae1fe8d427a2593
parentMore generic datset filter (diff)
downloadtextual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.tar.gz
textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.tar.bz2
textual-inversion-diff-8f4d212b3833041448678ad8a44a9a327934f74a.zip
Avoid increased VRAM usage on validation
-rw-r--r--data/csv.py13
-rw-r--r--dreambooth.py13
-rw-r--r--infer.py4
-rw-r--r--textual_inversion.py11
4 files changed, 29 insertions, 12 deletions
diff --git a/data/csv.py b/data/csv.py
index 20ac992..053457b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -54,6 +54,7 @@ class CSVDataModule(pl.LightningDataModule):
54 dropout: float = 0, 54 dropout: float = 0,
55 interpolation: str = "bicubic", 55 interpolation: str = "bicubic",
56 center_crop: bool = False, 56 center_crop: bool = False,
57 mode: Optional[str] = None,
57 template_key: str = "template", 58 template_key: str = "template",
58 valid_set_size: Optional[int] = None, 59 valid_set_size: Optional[int] = None,
59 generator: Optional[torch.Generator] = None, 60 generator: Optional[torch.Generator] = None,
@@ -80,6 +81,7 @@ class CSVDataModule(pl.LightningDataModule):
80 self.repeats = repeats 81 self.repeats = repeats
81 self.dropout = dropout 82 self.dropout = dropout
82 self.center_crop = center_crop 83 self.center_crop = center_crop
84 self.mode = mode
83 self.template_key = template_key 85 self.template_key = template_key
84 self.interpolation = interpolation 86 self.interpolation = interpolation
85 self.valid_set_size = valid_set_size 87 self.valid_set_size = valid_set_size
@@ -99,7 +101,7 @@ class CSVDataModule(pl.LightningDataModule):
99 self.data_root.joinpath(image.format(item["image"])), 101 self.data_root.joinpath(image.format(item["image"])),
100 None, 102 None,
101 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), 103 prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
102 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")) 104 nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else "")),
103 ) 105 )
104 for item in data 106 for item in data
105 ] 107 ]
@@ -118,7 +120,7 @@ class CSVDataModule(pl.LightningDataModule):
118 item.instance_image_path, 120 item.instance_image_path,
119 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), 121 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
120 item.prompt, 122 item.prompt,
121 item.nprompt 123 item.nprompt,
122 ) 124 )
123 for item in items 125 for item in items
124 for i in range(image_multiplier) 126 for i in range(image_multiplier)
@@ -130,7 +132,12 @@ class CSVDataModule(pl.LightningDataModule):
130 template = metadata[self.template_key] if self.template_key in metadata else {} 132 template = metadata[self.template_key] if self.template_key in metadata else {}
131 items = metadata["items"] if "items" in metadata else [] 133 items = metadata["items"] if "items" in metadata else []
132 134
133 items = [item for item in items if not "skip" in item or item["skip"] != True] 135 if self.mode is not None:
136 items = [
137 item
138 for item in items
139 if "mode" in item and self.mode in item["mode"]
140 ]
134 items = self.prepare_items(template, items) 141 items = self.prepare_items(template, items)
135 items = self.filter_items(items) 142 items = self.filter_items(items)
136 143
diff --git a/dreambooth.py b/dreambooth.py
index 96213d0..3eecf9c 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -130,6 +130,12 @@ def parse_args():
130 help="The embeddings directory where Textual Inversion embeddings are stored.", 130 help="The embeddings directory where Textual Inversion embeddings are stored.",
131 ) 131 )
132 parser.add_argument( 132 parser.add_argument(
133 "--mode",
134 type=str,
135 default=None,
136 help="A mode to filter the dataset.",
137 )
138 parser.add_argument(
133 "--seed", 139 "--seed",
134 type=int, 140 type=int,
135 default=None, 141 default=None,
@@ -284,7 +290,7 @@ def parse_args():
284 parser.add_argument( 290 parser.add_argument(
285 "--sample_frequency", 291 "--sample_frequency",
286 type=int, 292 type=int,
287 default=100, 293 default=1,
288 help="How often to save a checkpoint and sample image", 294 help="How often to save a checkpoint and sample image",
289 ) 295 )
290 parser.add_argument( 296 parser.add_argument(
@@ -759,6 +765,7 @@ def main():
759 num_class_images=args.num_class_images, 765 num_class_images=args.num_class_images,
760 size=args.resolution, 766 size=args.resolution,
761 repeats=args.repeats, 767 repeats=args.repeats,
768 mode=args.mode,
762 dropout=args.tag_dropout, 769 dropout=args.tag_dropout,
763 center_crop=args.center_crop, 770 center_crop=args.center_crop,
764 template_key=args.train_data_template, 771 template_key=args.train_data_template,
@@ -1046,7 +1053,7 @@ def main():
1046 unet.eval() 1053 unet.eval()
1047 text_encoder.eval() 1054 text_encoder.eval()
1048 1055
1049 with torch.autocast("cuda"), torch.inference_mode(): 1056 with torch.inference_mode():
1050 for step, batch in enumerate(val_dataloader): 1057 for step, batch in enumerate(val_dataloader):
1051 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 1058 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
1052 latents = latents * 0.18215 1059 latents = latents * 0.18215
@@ -1063,8 +1070,6 @@ def main():
1063 1070
1064 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 1071 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1065 1072
1066 model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
1067
1068 # Get the target for loss depending on the prediction type 1073 # Get the target for loss depending on the prediction type
1069 if noise_scheduler.config.prediction_type == "epsilon": 1074 if noise_scheduler.config.prediction_type == "epsilon":
1070 target = noise 1075 target = noise
diff --git a/infer.py b/infer.py
index efeb24d..420cb83 100644
--- a/infer.py
+++ b/infer.py
@@ -34,7 +34,7 @@ torch.backends.cudnn.benchmark = True
34default_args = { 34default_args = {
35 "model": "stabilityai/stable-diffusion-2-1", 35 "model": "stabilityai/stable-diffusion-2-1",
36 "precision": "fp32", 36 "precision": "fp32",
37 "ti_embeddings_dir": "embeddings_ti", 37 "ti_embeddings_dir": "embeddings",
38 "output_dir": "output/inference", 38 "output_dir": "output/inference",
39 "config": None, 39 "config": None,
40} 40}
@@ -190,7 +190,7 @@ def create_pipeline(model, embeddings_dir, dtype):
190 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 190 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
191 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) 191 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
192 192
193 added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir) 193 added_tokens = load_text_embeddings(tokenizer, text_encoder, Path(embeddings_dir))
194 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") 194 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
195 195
196 pipeline = VlpnStableDiffusion( 196 pipeline = VlpnStableDiffusion(
diff --git a/textual_inversion.py b/textual_inversion.py
index a849d2a..e281c73 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -112,6 +112,12 @@ def parse_args():
112 help="The embeddings directory where Textual Inversion embeddings are stored.", 112 help="The embeddings directory where Textual Inversion embeddings are stored.",
113 ) 113 )
114 parser.add_argument( 114 parser.add_argument(
115 "--mode",
116 type=str,
117 default=None,
118 help="A mode to filter the dataset.",
119 )
120 parser.add_argument(
115 "--seed", 121 "--seed",
116 type=int, 122 type=int,
117 default=None, 123 default=None,
@@ -679,6 +685,7 @@ def main():
679 num_class_images=args.num_class_images, 685 num_class_images=args.num_class_images,
680 size=args.resolution, 686 size=args.resolution,
681 repeats=args.repeats, 687 repeats=args.repeats,
688 mode=args.mode,
682 dropout=args.tag_dropout, 689 dropout=args.tag_dropout,
683 center_crop=args.center_crop, 690 center_crop=args.center_crop,
684 template_key=args.train_data_template, 691 template_key=args.train_data_template,
@@ -940,7 +947,7 @@ def main():
940 text_encoder.eval() 947 text_encoder.eval()
941 val_loss = 0.0 948 val_loss = 0.0
942 949
943 with torch.autocast("cuda"), torch.inference_mode(): 950 with torch.inference_mode():
944 for step, batch in enumerate(val_dataloader): 951 for step, batch in enumerate(val_dataloader):
945 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 952 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
946 latents = latents * 0.18215 953 latents = latents * 0.18215
@@ -958,8 +965,6 @@ def main():
958 965
959 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 966 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
960 967
961 model_pred, noise = accelerator.gather_for_metrics((model_pred, noise))
962
963 # Get the target for loss depending on the prediction type 968 # Get the target for loss depending on the prediction type
964 if noise_scheduler.config.prediction_type == "epsilon": 969 if noise_scheduler.config.prediction_type == "epsilon":
965 target = noise 970 target = noise