author    Volpeon <git@volpeon.ink>  2022-11-27 16:57:29 +0100
committer Volpeon <git@volpeon.ink>  2022-11-27 16:57:29 +0100
commit    b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d
tree      2ad3740868696fc071d8850171e6e53ccc3a7bd2
parent    Update
Update
 dreambooth.py    | 54
 environment.yaml |  4
 infer.py         | 52
 3 files changed, 63 insertions(+), 47 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 79b3d2c..2b8a35e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -859,7 +859,14 @@ def main():
     # Only show the progress bar once on each machine.
 
     global_step = 0
-    min_val_loss = np.inf
+
+    total_loss = 0.0
+    total_acc = 0.0
+
+    total_loss_val = 0.0
+    total_acc_val = 0.0
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
@@ -905,7 +912,6 @@ def main():
 
         unet.train()
         text_encoder.train()
-        train_loss = 0.0
 
         sample_checkpoint = False
 
@@ -978,8 +984,11 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                loss = loss.detach().item()
-                train_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss += loss.item()
+                total_acc += acc.item()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -996,7 +1005,10 @@ def main():
                     sample_checkpoint = True
 
             logs = {
-                "train/loss": loss,
+                "train/loss": total_loss / global_step,
+                "train/acc": total_acc / global_step,
+                "train/cur_loss": loss.item(),
+                "train/cur_acc": acc.item(),
                 "lr/unet": lr_scheduler.get_last_lr()[0],
                 "lr/text": lr_scheduler.get_last_lr()[1]
             }
@@ -1010,13 +1022,10 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         unet.eval()
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.autocast("cuda"), torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
@@ -1039,28 +1048,41 @@ def main():
 
                 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-                loss = loss.detach().item()
-                val_loss += loss
+                acc = (noise_pred == latents).float()
+                acc = acc.mean()
+
+                total_loss_val += loss.item()
+                total_acc_val += acc.item()
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
+                logs = {
+                    "val/loss": total_loss_val / global_step,
+                    "val/acc": total_acc_val / global_step,
+                    "val/cur_loss": loss.item(),
+                    "val/cur_acc": acc.item(),
+                }
                 local_progress_bar.set_postfix(**logs)
 
-        val_loss /= len(val_dataloader)
+        val_step = (epoch + 1) * len(val_dataloader)
+        avg_acc_val = total_acc_val / val_step
+        avg_loss_val = total_loss_val / val_step
 
-        accelerator.log({"val/loss": val_loss}, step=global_step)
+        accelerator.log({
+            "val/loss": avg_loss_val,
+            "val/acc": avg_acc_val,
+        }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
 
-        if min_val_loss > val_loss:
+        if avg_acc_val > max_acc_val:
             accelerator.print(
-                f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
             checkpointer.save_embedding(global_step, "milestone")
-            min_val_loss = val_loss
+            max_acc_val = avg_acc_val
 
         if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
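
Note on this file's changes: the per-epoch train_loss/val_loss averages are replaced by run-wide sums (total_loss, total_acc and their _val counterparts) divided by the step count at logging time, and milestone checkpoints now trigger on a new maximum of averaged validation accuracy rather than a new minimum of validation loss. A minimal standalone sketch of that running-mean pattern follows; every name in it is illustrative, not taken from the repository:

class RunningMean:
    """Accumulate a sum and a step count; mean() is the running average."""

    def __init__(self):
        self.total = 0.0
        self.steps = 0

    def update(self, value: float) -> None:
        self.total += value
        self.steps += 1

    def mean(self) -> float:
        # Guard against division by zero before the first update.
        return self.total / max(self.steps, 1)


train_loss = RunningMean()
max_acc_val = 0.0

for cur_loss in [0.9, 0.7, 0.6]:  # stand-in for per-step training losses
    train_loss.update(cur_loss)
    print({"train/loss": train_loss.mean(), "train/cur_loss": cur_loss})

avg_acc_val = 0.42  # stand-in for an epoch's averaged validation accuracy
if avg_acc_val > max_acc_val:
    print(f"Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
    max_acc_val = avg_acc_val

Keeping raw sums and dividing on demand avoids the old pattern of resetting train_loss each epoch and dividing by len(train_dataloader) at the end.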
diff --git a/environment.yaml b/environment.yaml
index 7aa5312..4972ebd 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
-  - xformers=0.0.14.dev315
+  - xformers=0.0.15.dev337
   - pip:
     - -e .
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
@@ -35,4 +35,4 @@ dependencies:
     - torch-fidelity==0.3.0
     - torchmetrics==0.9.3
     - transformers==4.23.1
-    - triton==2.0.0.dev20220924
+    - triton==2.0.0.dev20221105
diff --git a/infer.py b/infer.py
index 2bf9cb3..ab5f247 100644
--- a/infer.py
+++ b/infer.py
@@ -20,7 +20,6 @@ torch.backends.cuda.matmul.allow_tf32 = True
 
 default_args = {
     "model": None,
-    "scheduler": "dpmpp",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
     "output_dir": "output/inference",
@@ -29,6 +28,7 @@ default_args = {
 
 
 default_cmds = {
+    "scheduler": "dpmpp",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -62,11 +62,6 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--scheduler",
-        type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
-    )
-    parser.add_argument(
         "--precision",
         type=str,
         choices=["fp32", "fp16", "bf16"],
@@ -92,6 +87,11 @@ def create_cmd_parser():
         description="Simple example of a training script."
     )
     parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+    )
+    parser.add_argument(
         "--prompt",
         type=str,
         nargs="+",
@@ -199,37 +199,17 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
199 print(f"Loaded {placeholder_token}") 199 print(f"Loaded {placeholder_token}")
200 200
201 201
202def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): 202def create_pipeline(model, ti_embeddings_dir, dtype):
203 print("Loading Stable Diffusion pipeline...") 203 print("Loading Stable Diffusion pipeline...")
204 204
205 tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) 205 tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
206 text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) 206 text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
207 vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) 207 vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
208 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 208 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
209 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
209 210
210 load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir) 211 load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
211 212
212 if scheduler == "plms":
213 scheduler = PNDMScheduler(
214 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
215 )
216 elif scheduler == "klms":
217 scheduler = LMSDiscreteScheduler(
218 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
219 )
220 elif scheduler == "ddim":
221 scheduler = DDIMScheduler(
222 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
223 )
224 elif scheduler == "dpmpp":
225 scheduler = DPMSolverMultistepScheduler(
226 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
227 )
228 else:
229 scheduler = EulerAncestralDiscreteScheduler(
230 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
231 )
232
233 pipeline = VlpnStableDiffusion( 213 pipeline = VlpnStableDiffusion(
234 text_encoder=text_encoder, 214 text_encoder=text_encoder,
235 vae=vae, 215 vae=vae,
@@ -264,6 +244,17 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
+    if args.scheduler == "plms":
+        pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "klms":
+        pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "ddim":
+        pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmpp":
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "euler_a":
+        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
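
Note: scheduler selection now happens per generate() call by swapping pipeline.scheduler in place. Rebuilding a scheduler with from_config(pipeline.scheduler.config) reuses the noise-schedule parameters the pipeline was loaded with, which is why the hard-coded beta_start/beta_end constructors in create_pipeline could be dropped. The same pattern against a stock diffusers pipeline, with a placeholder model id that is not from this repository:

from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

# Placeholder model id; any Stable Diffusion checkpoint layout works here.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Only the solver changes; betas/timesteps come from the loaded scheduler's config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)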
@@ -331,6 +322,9 @@ class CmdParse(cmd.Cmd):
             generate(self.output_dir, self.pipeline, args)
         except KeyboardInterrupt:
             print('Generation cancelled.')
+        except Exception as e:
+            print(e)
+            return
 
     def do_exit(self, line):
         return True
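
Note: the added except clause keeps the interactive session alive; a failed generation prints its error and control returns to the prompt instead of unwinding cmdloop(). A self-contained sketch of that pattern with Python's cmd module (the Repl class and its fake failing command are hypothetical):

import cmd


class Repl(cmd.Cmd):
    prompt = "> "

    def default(self, line):
        # Report failures but keep the loop running, as in CmdParse above.
        try:
            if line == "boom":
                raise ValueError("simulated generation failure")
            print(f"ok: {line}")
        except KeyboardInterrupt:
            print("Generation cancelled.")
        except Exception as e:
            print(e)
            return

    def do_exit(self, line):
        return True


if __name__ == "__main__":
    Repl().cmdloop()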
@@ -345,7 +339,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()