From 6a49074dce78615bce54777fb2be3bfd0dd8f780 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 14 Oct 2022 20:03:01 +0200
Subject: Removed aesthetic gradients; training improvements

---
 aesthetic_gradient.py                              | 137 ---------------------
 dreambooth.py                                      |  10 +-
 dreambooth_plus.py                                 |  59 +++++++--
 infer.py                                           |  32 +----
 .../stable_diffusion/vlpn_stable_diffusion.py      |  50 +-------
 textual_inversion.py                               |  32 +++--
 6 files changed, 77 insertions(+), 243 deletions(-)
 delete mode 100644 aesthetic_gradient.py

diff --git a/aesthetic_gradient.py b/aesthetic_gradient.py
deleted file mode 100644
index 5386d0f..0000000
--- a/aesthetic_gradient.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import argparse
-import datetime
-import logging
-import json
-from pathlib import Path
-
-import torch
-import torch.utils.checkpoint
-from torchvision import transforms
-import pandas as pd
-
-from accelerate.logging import get_logger
-from PIL import Image
-from tqdm import tqdm
-from transformers import CLIPModel
-from slugify import slugify
-
-logger = get_logger(__name__)
-
-
-torch.backends.cuda.matmul.allow_tf32 = True
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        type=str,
-        default=None,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-    )
-    parser.add_argument(
-        "--train_data_file",
-        type=str,
-        default=None,
-        help="A directory."
-    )
-    parser.add_argument(
-        "--token",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--resolution",
-        type=int,
-        default=224,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output/aesthetic-gradient",
-        help="The output directory where the model predictions and checkpoints will be written.",
-    )
-    parser.add_argument(
-        "--config",
-        type=str,
-        default=None,
-        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
-    )
-
-    args = parser.parse_args()
-    if args.config is not None:
-        with open(args.config, 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
-
-    if args.train_data_file is None:
-        raise ValueError("You must specify --train_data_file")
-
-    if args.token is None:
-        raise ValueError("You must specify --token")
-
-    if args.output_dir is None:
-        raise ValueError("You must specify --output_dir")
-
-    return args
-
-
-def main():
-    args = parse_args()
-
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir)
-    basepath.mkdir(parents=True, exist_ok=True)
-    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")
-
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
-
-    data_file = Path(args.train_data_file)
-    if not data_file.is_file():
-        raise ValueError("data_file must be a file")
-    data_root = data_file.parent
-    metadata = pd.read_csv(data_file)
-    image_paths = [
-        data_root.joinpath(item.image)
-        for item in metadata.itertuples()
-        if "skip" not in item or item.skip != "x"
-    ]
-
-    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-
-    image_transforms = transforms.Compose(
-        [
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
-            transforms.RandomCrop(args.resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ]
-    )
-
-    with torch.no_grad():
-        embs = []
-        for path in tqdm(image_paths):
-            image = Image.open(path)
-            if not image.mode == "RGB":
-                image = image.convert("RGB")
-            image = image_transforms(image).unsqueeze(0)
-            emb = model.get_image_features(image)
-            print(f">>>> {emb.shape}")
-            embs.append(emb)
-
-        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
-
-        print(embs.shape)
-
-        torch.save(embs, target)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/dreambooth.py b/dreambooth.py
index 072142e..1ba8dc0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -70,7 +70,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
        type=int,
-        default=3000,
+        default=2000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -341,7 +341,7 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size

     @torch.no_grad()
-    def checkpoint(self):
+    def save_model(self):
         print("Saving model...")

         unwrapped = self.accelerator.unwrap_model(
@@ -839,14 +839,14 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint()
+            checkpointer.save_model()

         accelerator.end_training()

     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint()
+            checkpointer.save_model()
         accelerator.end_training()
         quit()

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 7996bc2..b5ec2fc 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -58,6 +58,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
@@ -71,7 +77,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -112,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1600,
+        default=2300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-4,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -221,6 +227,12 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
+    parser.add_argument(
+        "--checkpoint_frequency",
+        type=int,
+        default=500,
+        help="How often to save a checkpoint and sample image",
+    )
     parser.add_argument(
         "--sample_frequency",
         type=int,
@@ -352,7 +364,26 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size

     @torch.no_grad()
-    def checkpoint(self):
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+
+        # Save a checkpoint
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+
+        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+
+        del unwrapped
+        del learned_embeds
+
+    @torch.no_grad()
+    def save_model(self):
         print("Saving model...")

         unwrapped_unet = self.accelerator.unwrap_model(
@@ -612,7 +643,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -648,7 +679,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]

                 images = pipeline(
@@ -842,6 +873,12 @@ def main():
                 if global_step % args.sample_frequency == 0:
                     sample_checkpoint = True

+                if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+
                 logs = {
                     "train/loss": loss,
                     "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -903,6 +940,9 @@ def main():
                 global_progress_bar.clear()

                 if min_val_loss > val_loss:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                     min_val_loss = val_loss

                 if sample_checkpoint and accelerator.is_main_process:
@@ -913,14 +953,15 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint()
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()

         accelerator.end_training()

     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint()
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
         accelerator.end_training()
         quit()

diff --git a/infer.py b/infer.py
index 650c119..1a0baf5 100644
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,6 @@ default_args = {
     "scheduler": "euler_a",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
-    "ag_embeddings_dir": "embeddings_ag",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -77,10 +76,6 @@ def create_args_parser():
         "--ti_embeddings_dir",
         type=str,
     )
-    parser.add_argument(
-        "--ag_embeddings_dir",
-        type=str,
-    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -211,24 +206,7 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
         print(f"Loaded {placeholder_token}")


-def load_embeddings_ag(pipeline, embeddings_dir):
-    print(f"Loading Aesthetic Gradient embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-
-            data = torch.load(file, map_location="cpu")
-
-            pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
-
-            print(f"Loaded {placeholder_token}")
-
-
-def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
+def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")

     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -262,13 +240,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtyp
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
-    pipeline.aesthetic_gradient_iters = 30
+    pipeline.aesthetic_gradient_iters = 20
     pipeline.to("cuda")

     print("Pipeline loaded.")

-    load_embeddings_ag(pipeline, ag_embeddings_dir)
-
     return pipeline


@@ -288,7 +264,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None

-    with torch.autocast("cuda"):
+    with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
                 desc=f"Batch {i + 1} of {args.batch_num}",
@@ -366,7 +342,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]

-    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 1a84c8d..3e41f86 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -51,10 +51,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

-        self.aesthetic_gradient_embeddings = {}
-        self.aesthetic_gradient_lr = 1e-4
-        self.aesthetic_gradient_iters = 10
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -63,46 +59,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         scheduler=scheduler,
     )

-    def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor):
-        self.aesthetic_gradient_embeddings[keyword] = tensor
-
-    def get_text_embeddings(self, prompt, text_input_ids):
-        prompt = " ".join(prompt)
-
-        embeddings = [
-            embedding
-            for key, embedding in self.aesthetic_gradient_embeddings.items()
-            if key in prompt
-        ]
-
-        if len(embeddings) != 0:
-            with torch.enable_grad():
-                full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-                full_clip_model.to(self.device)
-                full_clip_model.text_model.train()
-
-                optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
-
-                for embs in embeddings:
-                    embs = embs.clone().detach().to(self.device)
-                    embs /= embs.norm(dim=-1, keepdim=True)
-
-                    for i in range(self.aesthetic_gradient_iters):
-                        text_embs = full_clip_model.get_text_features(text_input_ids)
-                        text_embs /= text_embs.norm(dim=-1, keepdim=True)
-                        sim = text_embs @ embs.T
-                        loss = -sim
-                        loss = loss.mean()
-
-                        loss.backward()
-                        optimizer.step()
-                        optimizer.zero_grad()
-
-                full_clip_model.text_model.eval()
-
-                return full_clip_model.text_model(text_input_ids)[0]
-        else:
-            return self.text_encoder(text_input_ids)[0]
+    def get_text_embeddings(self, text_input_ids):
+        return self.text_encoder(text_input_ids)[0]

     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -241,7 +199,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device))
+        text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device))

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -253,7 +211,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device))
+            uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device))

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
diff --git a/textual_inversion.py b/textual_inversion.py
index 9d2840d..6627f1f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -57,6 +57,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
@@ -70,7 +76,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -344,12 +350,11 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size

     @torch.no_grad()
-    def checkpoint(self, step, postfix, path=None):
+    def checkpoint(self, step, postfix):
         print("Saving checkpoint for step %d..." % step)

-        if path is None:
-            checkpoints_path = self.output_dir.joinpath("checkpoints")
-            checkpoints_path.mkdir(parents=True, exist_ok=True)
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)

         unwrapped = self.accelerator.unwrap_model(self.text_encoder)

@@ -358,10 +363,7 @@ class Checkpointer:
         learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}

         filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        if path is not None:
-            torch.save(learned_embeds_dict, path)
-        else:
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))

         del unwrapped
         del learned_embeds
@@ -595,7 +597,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -631,7 +633,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]

                 images = pipeline(
@@ -898,17 +900,11 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint(
-                global_step + global_step_offset,
-                "end",
-                path=f"{basepath}/learned_embeds.bin"
-            )
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
-
         accelerator.end_training()

     except KeyboardInterrupt:
-- 
cgit v1.2.3-54-g00ecf