From 64c79cc3e7fad49131f90fbb0648b6d5587563e5 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sat, 10 Dec 2022 08:43:34 +0100
Subject: Various updates; shuffle prompt content during training

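shuffle_prompt() splits a prompt into ". "-separated blocks, shuffles
the ", "-separated tags within each block, then shuffles the blocks
themselves, so training no longer sees the tags in a fixed order. The
input IDs are now derived in __getitem__ from the freshly shuffled
prompt rather than from the raw prompt in get_example(); the class
prompt IDs are likewise built from the shuffled prompt (with the class
identifier) instead of the negative prompt.

A minimal sketch of the intended behaviour (the prompt text is made up;
the exact ordering depends on NumPy's RNG state):

    >>> shuffle_prompt("a dog, sitting, park. detailed, high quality")
    'high quality, detailed. park, a dog, sitting'

Further changes: freeze the VAE via requires_grad_(False) instead of
freeze_params(), print the placeholder token ID mappings during setup,
lower the default sample steps (20 -> 15) and inference steps (50 ->
30), and add the DPM-Solver single-step, KDPM2 and KDPM2 ancestral
schedulers to infer.py ("dpmpp" becomes "dpmsm").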
---
 data/csv.py   | 29 ++++++++++++++++++++---------
 dreambooth.py | 10 ++++++----
 infer.py      | 27 ++++++++++++++++++++++-----
 3 files changed, 48 insertions(+), 18 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 67ac43b..23b5299 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,6 +1,7 @@
 import math
 import torch
 import json
+import numpy as np
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
@@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
+def shuffle_prompt(prompt: str):
+    def handle_block(block: str):
+        words = block.split(", ")
+        np.random.shuffle(words)
+        return ", ".join(words)
+
+    blocks = prompt.split(". ")
+    blocks = [handle_block(b) for b in blocks]
+    np.random.shuffle(blocks)
+    prompt = ". ".join(blocks)
+    return prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -190,30 +204,27 @@ class CSVDataset(Dataset):
         item = self.data[i % self.num_instance_images]
 
         example = {}
-
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt
-
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
-
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
         return example
 
     def __getitem__(self, i):
-        example = {}
         unprocessed_example = self.get_example(i)
 
-        example["prompts"] = unprocessed_example["prompts"]
+        example = {}
+
+        example["prompts"] = shuffle_prompt(unprocessed_example["prompts"])
         example["nprompts"] = unprocessed_example["nprompts"]
+
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
-        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
+        example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
-            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
+            example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)
 
         return example
diff --git a/dreambooth.py b/dreambooth.py
index ec9531e..0044c1e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,13 +1,11 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -299,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -613,7 +611,7 @@ def main():
         )
 
     # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    vae.requires_grad_(False)
 
     if len(args.placeholder_token) != 0:
         print(f"Adding text embeddings: {args.placeholder_token}")
@@ -629,6 +627,10 @@ def main():
 
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+        print(f"Token ID mappings:")
+        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+            print(f"- {token_id} {token}")
+
         # Resize the token embeddings as we are adding new special tokens to the tokenizer
         text_encoder.resize_token_embeddings(len(tokenizer))
 
diff --git a/infer.py b/infer.py
index 30e11cf..e3fa9e5 100644
--- a/infer.py
+++ b/infer.py
@@ -8,7 +8,18 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    PNDMScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DDIMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
@@ -29,7 +40,7 @@ default_args = {
 
 
 default_cmds = {
-    "scheduler": "dpmpp",
+    "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -38,7 +49,7 @@ default_cmds = {
     "height": 512,
     "batch_size": 1,
     "batch_num": 1,
-    "steps": 50,
+    "steps": 30,
     "guidance_scale": 7.0,
     "seed": None,
     "config": None,
@@ -90,7 +101,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
     )
     parser.add_argument(
         "--prompt",
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args):
         pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "ddim":
         pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
-    elif args.scheduler == "dpmpp":
+    elif args.scheduler == "dpmsm":
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmss":
+        pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "euler_a":
         pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2":
+        pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2_a":
+        pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
-- 
cgit v1.2.3-70-g09d2