| author | Volpeon <git@volpeon.ink> | 2022-10-14 20:03:01 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-14 20:03:01 +0200 |
| commit | 6a49074dce78615bce54777fb2be3bfd0dd8f780 (patch) | |
| tree | 0f7dde5ea6b6343fb6e0a527e5ebb2940d418dce | |
| parent | Added support for Aesthetic Gradients (diff) | |
Removed aesthetic gradients; training improvements
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | aesthetic_gradient.py | 137 |
| -rw-r--r-- | dreambooth.py | 10 |
| -rw-r--r-- | dreambooth_plus.py | 59 |
| -rw-r--r-- | infer.py | 32 |
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 50 |
| -rw-r--r-- | textual_inversion.py | 32 |
6 files changed, 77 insertions, 243 deletions
diff --git a/aesthetic_gradient.py b/aesthetic_gradient.py
deleted file mode 100644
index 5386d0f..0000000
--- a/aesthetic_gradient.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import argparse
-import datetime
-import logging
-import json
-from pathlib import Path
-
-import torch
-import torch.utils.checkpoint
-from torchvision import transforms
-import pandas as pd
-
-from accelerate.logging import get_logger
-from PIL import Image
-from tqdm import tqdm
-from transformers import CLIPModel
-from slugify import slugify
-
-logger = get_logger(__name__)
-
-
-torch.backends.cuda.matmul.allow_tf32 = True
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        type=str,
-        default=None,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-    )
-    parser.add_argument(
-        "--train_data_file",
-        type=str,
-        default=None,
-        help="A directory."
-    )
-    parser.add_argument(
-        "--token",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--resolution",
-        type=int,
-        default=224,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output/aesthetic-gradient",
-        help="The output directory where the model predictions and checkpoints will be written.",
-    )
-    parser.add_argument(
-        "--config",
-        type=str,
-        default=None,
-        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
-    )
-
-    args = parser.parse_args()
-    if args.config is not None:
-        with open(args.config, 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
-
-    if args.train_data_file is None:
-        raise ValueError("You must specify --train_data_file")
-
-    if args.token is None:
-        raise ValueError("You must specify --token")
-
-    if args.output_dir is None:
-        raise ValueError("You must specify --output_dir")
-
-    return args
-
-
-def main():
-    args = parse_args()
-
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir)
-    basepath.mkdir(parents=True, exist_ok=True)
-    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")
-
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
-
-    data_file = Path(args.train_data_file)
-    if not data_file.is_file():
-        raise ValueError("data_file must be a file")
-    data_root = data_file.parent
-    metadata = pd.read_csv(data_file)
-    image_paths = [
-        data_root.joinpath(item.image)
-        for item in metadata.itertuples()
-        if "skip" not in item or item.skip != "x"
-    ]
-
-    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-
-    image_transforms = transforms.Compose(
-        [
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
-            transforms.RandomCrop(args.resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ]
-    )
-
-    with torch.no_grad():
-        embs = []
-        for path in tqdm(image_paths):
-            image = Image.open(path)
-            if not image.mode == "RGB":
-                image = image.convert("RGB")
-            image = image_transforms(image).unsqueeze(0)
-            emb = model.get_image_features(image)
-            print(f">>>> {emb.shape}")
-            embs.append(emb)
-
-    embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
-
-    print(embs.shape)
-
-    torch.save(embs, target)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/dreambooth.py b/dreambooth.py
index 072142e..1ba8dc0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -70,7 +70,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=3000,
+        default=2000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -341,7 +341,7 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size

     @torch.no_grad()
-    def checkpoint(self):
+    def save_model(self):
         print("Saving model...")

         unwrapped = self.accelerator.unwrap_model(
@@ -839,14 +839,14 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint()
+            checkpointer.save_model()

         accelerator.end_training()

     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint()
+            checkpointer.save_model()
             accelerator.end_training()
             quit()

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 7996bc2..b5ec2fc 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -58,6 +58,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
@@ -71,7 +77,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -112,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1600,
+        default=2300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
        type=float,
-        default=5e-4,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -222,6 +228,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--checkpoint_frequency",
+        type=int,
+        default=500,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=100,
@@ -352,7 +364,26 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size

     @torch.no_grad()
-    def checkpoint(self):
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+
+        # Save a checkpoint
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+
+        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+
+        del unwrapped
+        del learned_embeds
+
+    @torch.no_grad()
+    def save_model(self):
         print("Saving model...")

         unwrapped_unet = self.accelerator.unwrap_model(
@@ -612,7 +643,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -648,7 +679,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]

                 images = pipeline(
@@ -842,6 +873,12 @@ def main():
                 if global_step % args.sample_frequency == 0:
                     sample_checkpoint = True

+                if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+
                 logs = {
                     "train/loss": loss,
                     "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -903,6 +940,9 @@ def main():
             global_progress_bar.clear()

             if min_val_loss > val_loss:
+                accelerator.print(
+                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                 min_val_loss = val_loss

             if sample_checkpoint and accelerator.is_main_process:
@@ -913,14 +953,15 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint()
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
         accelerator.end_training()

     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint()
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
             accelerator.end_training()
             quit()

diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,6 @@ default_args = {
     "scheduler": "euler_a",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
-    "ag_embeddings_dir": "embeddings_ag",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -78,10 +77,6 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--ag_embeddings_dir",
-        type=str,
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
     )
@@ -211,24 +206,7 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
             print(f"Loaded {placeholder_token}")


-def load_embeddings_ag(pipeline, embeddings_dir):
-    print(f"Loading Aesthetic Gradient embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-
-            data = torch.load(file, map_location="cpu")
-
-            pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
-
-            print(f"Loaded {placeholder_token}")
-
-
-def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
+def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")

     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -262,13 +240,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
-    pipeline.aesthetic_gradient_iters = 30
+    pipeline.aesthetic_gradient_iters = 20
     pipeline.to("cuda")

     print("Pipeline loaded.")

-    load_embeddings_ag(pipeline, ag_embeddings_dir)
-
     return pipeline


@@ -288,7 +264,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None

-    with torch.autocast("cuda"):
+    with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
                 desc=f"Batch {i + 1} of {args.batch_num}",
@@ -366,7 +342,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]

-    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 1a84c8d..3e41f86 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -51,10 +51,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)

-        self.aesthetic_gradient_embeddings = {}
-        self.aesthetic_gradient_lr = 1e-4
-        self.aesthetic_gradient_iters = 10
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -63,46 +59,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )

-    def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor):
-        self.aesthetic_gradient_embeddings[keyword] = tensor
-
-    def get_text_embeddings(self, prompt, text_input_ids):
-        prompt = " ".join(prompt)
-
-        embeddings = [
-            embedding
-            for key, embedding in self.aesthetic_gradient_embeddings.items()
-            if key in prompt
-        ]
-
-        if len(embeddings) != 0:
-            with torch.enable_grad():
-                full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-                full_clip_model.to(self.device)
-                full_clip_model.text_model.train()
-
-                optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
-
-                for embs in embeddings:
-                    embs = embs.clone().detach().to(self.device)
-                    embs /= embs.norm(dim=-1, keepdim=True)
-
-                    for i in range(self.aesthetic_gradient_iters):
-                        text_embs = full_clip_model.get_text_features(text_input_ids)
-                        text_embs /= text_embs.norm(dim=-1, keepdim=True)
-                        sim = text_embs @ embs.T
-                        loss = -sim
-                        loss = loss.mean()
-
-                        loss.backward()
-                        optimizer.step()
-                        optimizer.zero_grad()
-
-                full_clip_model.text_model.eval()
-
-                return full_clip_model.text_model(text_input_ids)[0]
-        else:
-            return self.text_encoder(text_input_ids)[0]
+    def get_text_embeddings(self, text_input_ids):
+        return self.text_encoder(text_input_ids)[0]

     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -241,7 +199,7 @@
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device))
+        text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device))

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -253,7 +211,7 @@
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device))
+            uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device))

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
diff --git a/textual_inversion.py b/textual_inversion.py
index 9d2840d..6627f1f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -57,6 +57,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
@@ -70,7 +76,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -344,12 +350,11 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size

     @torch.no_grad()
-    def checkpoint(self, step, postfix, path=None):
+    def checkpoint(self, step, postfix):
         print("Saving checkpoint for step %d..." % step)

-        if path is None:
-            checkpoints_path = self.output_dir.joinpath("checkpoints")
-            checkpoints_path.mkdir(parents=True, exist_ok=True)
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)

         unwrapped = self.accelerator.unwrap_model(self.text_encoder)

@@ -358,10 +363,7 @@
         learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}

         filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        if path is not None:
-            torch.save(learned_embeds_dict, path)
-        else:
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))

         del unwrapped
         del learned_embeds
@@ -595,7 +597,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -631,7 +633,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]

                 images = pipeline(
@@ -898,17 +900,11 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint(
-                global_step + global_step_offset,
-                "end",
-                path=f"{basepath}/learned_embeds.bin"
-            )
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
-
         accelerator.end_training()

     except KeyboardInterrupt:
