From 64c79cc3e7fad49131f90fbb0648b6d5587563e5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 10 Dec 2022 08:43:34 +0100
Subject: Various updated; shuffle prompt content during training

---
 data/csv.py   | 29 ++++++++++++++++++++---------
 dreambooth.py | 10 ++++++----
 infer.py      | 27 ++++++++++++++++++++++-----
 3 files changed, 48 insertions(+), 18 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 67ac43b..23b5299 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,6 +1,7 @@
 import math
 import torch
 import json
+import numpy as np
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
@@ -15,6 +16,19 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
+def shuffle_prompt(prompt: str):
+    def handle_block(block: str):
+        words = block.split(", ")
+        np.random.shuffle(words)
+        return ", ".join(words)
+
+    prompt = prompt.split(". ")
+    prompt = [handle_block(b) for b in prompt]
+    np.random.shuffle(prompt)
+    prompt = ". ".join(prompt)
+    return prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -190,30 +204,27 @@ class CSVDataset(Dataset):
         item = self.data[i % self.num_instance_images]
 
         example = {}
-
         example["prompts"] = item.prompt
         example["nprompts"] = item.nprompt
-
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
-
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
-
         return example
 
     def __getitem__(self, i):
-        example = {}
         unprocessed_example = self.get_example(i)
 
-        example["prompts"] = unprocessed_example["prompts"]
+        example = {}
+
+        example["prompts"] = shuffle_prompt(unprocessed_example["prompts"])
         example["nprompts"] = unprocessed_example["nprompts"]
+
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
-        example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
+        example["instance_prompt_ids"] = self.get_input_ids(example["prompts"], self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
-            example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
+            example["class_prompt_ids"] = self.get_input_ids(example["prompts"], self.class_identifier)
 
         return example
diff --git a/dreambooth.py b/dreambooth.py
index ec9531e..0044c1e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,13 +1,11 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -299,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -613,7 +611,7 @@ def main():
     )
 
     # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    vae.requires_grad_(False)
 
     if len(args.placeholder_token) != 0:
         print(f"Adding text embeddings: {args.placeholder_token}")
@@ -629,6 +627,10 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    print(f"Token ID mappings:")
+    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+        print(f"- {token_id} {token}")
+
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
diff --git a/infer.py b/infer.py
index 30e11cf..e3fa9e5 100644
--- a/infer.py
+++ b/infer.py
@@ -8,7 +8,18 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    PNDMScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    DDIMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    KDPM2DiscreteScheduler,
+    KDPM2AncestralDiscreteScheduler
+)
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
@@ -29,7 +40,7 @@ default_args = {
 
 
 default_cmds = {
-    "scheduler": "dpmpp",
+    "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
     "image": None,
@@ -38,7 +49,7 @@ default_cmds = {
     "height": 512,
     "batch_size": 1,
     "batch_num": 1,
-    "steps": 50,
+    "steps": 30,
     "guidance_scale": 7.0,
     "seed": None,
     "config": None,
@@ -90,7 +101,7 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
+        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
     )
     parser.add_argument(
         "--prompt",
@@ -252,10 +263,16 @@ def generate(output_dir, pipeline, args):
         pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "ddim":
         pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
-    elif args.scheduler == "dpmpp":
+    elif args.scheduler == "dpmsm":
         pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "dpmss":
+        pipeline.scheduler = DPMSolverSinglestepScheduler.from_config(pipeline.scheduler.config)
     elif args.scheduler == "euler_a":
         pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2":
+        pipeline.scheduler = KDPM2DiscreteScheduler.from_config(pipeline.scheduler.config)
+    elif args.scheduler == "kdpm2_a":
+        pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
     with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
-- 
cgit v1.2.3-70-g09d2