From b9a04c32b8efe5f0a9bc5d369cbf3c5bcc12d00d Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 27 Nov 2022 16:57:29 +0100
Subject: Update

---
 dreambooth.py    | 54 ++++++++++++++++++++++++++++++++++++++----------------
 environment.yaml |  4 ++--
 infer.py         | 52 +++++++++++++++++++++++-----------------------------
 3 files changed, 63 insertions(+), 47 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index 79b3d2c..2b8a35e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -859,7 +859,14 @@ def main():
 
     # Only show the progress bar once on each machine.
     global_step = 0
-    min_val_loss = np.inf
+
+    total_loss = 0.0
+    total_acc = 0.0
+
+    total_loss_val = 0.0
+    total_acc_val = 0.0
+
+    max_acc_val = 0.0
 
     checkpointer = Checkpointer(
         datamodule=datamodule,
@@ -905,7 +912,6 @@ def main():
 
         unet.train()
         text_encoder.train()
-        train_loss = 0.0
 
         sample_checkpoint = False
 
@@ -978,8 +984,11 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-            loss = loss.detach().item()
-            train_loss += loss
+            acc = (noise_pred == noise).float()
+            acc = acc.mean()
+
+            total_loss += loss.item()
+            total_acc += acc.item()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -996,7 +1005,10 @@ def main():
                 sample_checkpoint = True
 
                 logs = {
-                    "train/loss": loss,
+                    "train/loss": total_loss / global_step,
+                    "train/acc": total_acc / global_step,
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
                     "lr/unet": lr_scheduler.get_last_lr()[0],
                     "lr/text": lr_scheduler.get_last_lr()[1]
                 }
@@ -1010,13 +1022,10 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        train_loss /= len(train_dataloader)
-
         accelerator.wait_for_everyone()
 
         unet.eval()
         text_encoder.eval()
-        val_loss = 0.0
 
         with torch.autocast("cuda"), torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
@@ -1039,28 +1048,41 @@ def main():
 
                     loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-                    loss = loss.detach().item()
-                    val_loss += loss
+                    acc = (noise_pred == noise).float()
+                    acc = acc.mean()
+
+                    total_loss_val += loss.item()
+                    total_acc_val += acc.item()
 
                     if accelerator.sync_gradients:
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
 
-                    logs = {"val/loss": loss}
+                    logs = {
+                        "val/loss": total_loss_val / global_step,
+                        "val/acc": total_acc_val / global_step,
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
                     local_progress_bar.set_postfix(**logs)
 
-            val_loss /= len(val_dataloader)
+            val_step = (epoch + 1) * len(val_dataloader)
+            avg_acc_val = total_acc_val / val_step
+            avg_loss_val = total_loss_val / val_step
 
-            accelerator.log({"val/loss": val_loss}, step=global_step)
+            accelerator.log({
+                "val/loss": avg_loss_val,
+                "val/acc": avg_acc_val,
+            }, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if min_val_loss > val_loss:
+            if avg_acc_val > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
                 checkpointer.save_embedding(global_step, "milestone")
-                min_val_loss = val_loss
+                max_acc_val = avg_acc_val
 
         if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
diff --git a/environment.yaml b/environment.yaml
index 7aa5312..4972ebd 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
-  - xformers=0.0.14.dev315
+  - xformers=0.0.15.dev337
   - pip:
     - -e .
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
@@ -35,4 +35,4 @@ dependencies:
     - torch-fidelity==0.3.0
     - torchmetrics==0.9.3
     - transformers==4.23.1
-    - triton==2.0.0.dev20220924
+    - triton==2.0.0.dev20221105
diff --git a/infer.py b/infer.py
index 2bf9cb3..ab5f247 100644
--- a/infer.py
+++ b/infer.py
@@ -20,7 +20,6 @@ torch.backends.cuda.matmul.allow_tf32 = True
 
 default_args = {
     "model": None,
-    "scheduler": "dpmpp",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
     "output_dir": "output/inference",
@@ -29,6 +28,7 @@ default_args = {
 
 
 default_cmds = {
+    "scheduler": "dpmpp",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -61,11 +61,6 @@ def create_args_parser():
         "--model",
         type=str,
     )
-    parser.add_argument(
-        "--scheduler",
-        type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
-    )
     parser.add_argument(
         "--precision",
         type=str,
@@ -91,6 +86,11 @@ def create_cmd_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
     )
+    parser.add_argument(
+        "--scheduler",
+        type=str,
+        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+    )
     parser.add_argument(
         "--prompt",
         type=str,
@@ -199,37 +199,17 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
     print(f"Loaded {placeholder_token}")
 
 
-def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
+def create_pipeline(model, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
     text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
+    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
     load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
 
-    if scheduler == "plms":
-        scheduler = PNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        )
-    elif scheduler == "klms":
-        scheduler = LMSDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-    elif scheduler == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
-        )
-    elif scheduler == "dpmpp":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-    else:
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
@@ -264,6 +244,17 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
+    if args.scheduler == "plms":
+        pipeline.scheduler = PNDMScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "klms":
+        pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "ddim":
+        pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmpp":
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "euler_a":
+        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
@@ -331,6 +322,9 @@ class CmdParse(cmd.Cmd):
             generate(self.output_dir, self.pipeline, args)
         except KeyboardInterrupt:
             print('Generation cancelled.')
+        except Exception as e:
+            print(e)
+            return
 
     def do_exit(self, line):
         return True
@@ -345,7 +339,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype)
 
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
-- 
cgit v1.2.3-70-g09d2