| -rw-r--r-- | data/csv.py | 162 |
| -rw-r--r-- | dreambooth.py | 75 |
| -rw-r--r-- | environment.yaml | 2 |
| -rw-r--r-- | infer.py | 12 |
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 5 |
| -rw-r--r-- | textual_inversion.py | 57 |
6 files changed, 169 insertions, 144 deletions
diff --git a/data/csv.py b/data/csv.py
index dcaf7d3..8637ac1 100644
--- a/data/csv.py
+++ b/data/csv.py
| @@ -1,27 +1,38 @@ | |||
| 1 | import math | ||
| 1 | import pandas as pd | 2 | import pandas as pd |
| 2 | from pathlib import Path | 3 | from pathlib import Path |
| 3 | import pytorch_lightning as pl | 4 | import pytorch_lightning as pl |
| 4 | from PIL import Image | 5 | from PIL import Image |
| 5 | from torch.utils.data import Dataset, DataLoader, random_split | 6 | from torch.utils.data import Dataset, DataLoader, random_split |
| 6 | from torchvision import transforms | 7 | from torchvision import transforms |
| 8 | from typing import NamedTuple, List | ||
| 9 | |||
| 10 | |||
| 11 | class CSVDataItem(NamedTuple): | ||
| 12 | instance_image_path: Path | ||
| 13 | class_image_path: Path | ||
| 14 | prompt: str | ||
| 15 | nprompt: str | ||
| 7 | 16 | ||
| 8 | 17 | ||
| 9 | class CSVDataModule(pl.LightningDataModule): | 18 | class CSVDataModule(pl.LightningDataModule): |
| 10 | def __init__(self, | 19 | def __init__( |
| 11 | batch_size, | 20 | self, |
| 12 | data_file, | 21 | batch_size, |
| 13 | tokenizer, | 22 | data_file, |
| 14 | instance_identifier, | 23 | tokenizer, |
| 15 | class_identifier=None, | 24 | instance_identifier, |
| 16 | class_subdir="db_cls", | 25 | class_identifier=None, |
| 17 | num_class_images=2, | 26 | class_subdir="db_cls", |
| 18 | size=512, | 27 | num_class_images=100, |
| 19 | repeats=100, | 28 | size=512, |
| 20 | interpolation="bicubic", | 29 | repeats=100, |
| 21 | center_crop=False, | 30 | interpolation="bicubic", |
| 22 | valid_set_size=None, | 31 | center_crop=False, |
| 23 | generator=None, | 32 | valid_set_size=None, |
| 24 | collate_fn=None): | 33 | generator=None, |
| 34 | collate_fn=None | ||
| 35 | ): | ||
| 25 | super().__init__() | 36 | super().__init__() |
| 26 | 37 | ||
| 27 | self.data_file = Path(data_file) | 38 | self.data_file = Path(data_file) |
| @@ -46,61 +57,50 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 46 | self.collate_fn = collate_fn | 57 | self.collate_fn = collate_fn |
| 47 | self.batch_size = batch_size | 58 | self.batch_size = batch_size |
| 48 | 59 | ||
| 60 | def prepare_subdata(self, data, num_class_images=1): | ||
| 61 | image_multiplier = max(math.ceil(num_class_images / len(data)), 1) | ||
| 62 | |||
| 63 | return [ | ||
| 64 | CSVDataItem( | ||
| 65 | self.data_root.joinpath(item.image), | ||
| 66 | self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"), | ||
| 67 | item.prompt, | ||
| 68 | item.nprompt if "nprompt" in item else "" | ||
| 69 | ) | ||
| 70 | for item in data | ||
| 71 | if "skip" not in item or item.skip != "x" | ||
| 72 | for i in range(image_multiplier) | ||
| 73 | ] | ||
| 74 | |||
| 49 | def prepare_data(self): | 75 | def prepare_data(self): |
| 50 | metadata = pd.read_csv(self.data_file) | 76 | metadata = pd.read_csv(self.data_file) |
| 51 | instance_image_paths = [ | 77 | metadata = list(metadata.itertuples()) |
| 52 | self.data_root.joinpath(f) | 78 | num_images = len(metadata) |
| 53 | for f in metadata['image'].values | ||
| 54 | for i in range(self.num_class_images) | ||
| 55 | ] | ||
| 56 | class_image_paths = [ | ||
| 57 | self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}") | ||
| 58 | for f in metadata['image'].values | ||
| 59 | for i in range(self.num_class_images) | ||
| 60 | ] | ||
| 61 | prompts = [ | ||
| 62 | prompt | ||
| 63 | for prompt in metadata['prompt'].values | ||
| 64 | for i in range(self.num_class_images) | ||
| 65 | ] | ||
| 66 | nprompts = [ | ||
| 67 | nprompt | ||
| 68 | for nprompt in metadata['nprompt'].values | ||
| 69 | for i in range(self.num_class_images) | ||
| 70 | ] if 'nprompt' in metadata else [""] * len(instance_image_paths) | ||
| 71 | skips = [ | ||
| 72 | skip | ||
| 73 | for skip in metadata['skip'].values | ||
| 74 | for i in range(self.num_class_images) | ||
| 75 | ] if 'skip' in metadata else [""] * len(instance_image_paths) | ||
| 76 | self.data = [ | ||
| 77 | (i, c, p, n) | ||
| 78 | for i, c, p, n, s | ||
| 79 | in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) | ||
| 80 | if s != "x" | ||
| 81 | ] | ||
| 82 | 79 | ||
| 83 | def setup(self, stage=None): | 80 | valid_set_size = int(num_images * 0.2) |
| 84 | valid_set_size = int(len(self.data) * 0.2) | ||
| 85 | if self.valid_set_size: | 81 | if self.valid_set_size: |
| 86 | valid_set_size = min(valid_set_size, self.valid_set_size) | 82 | valid_set_size = min(valid_set_size, self.valid_set_size) |
| 87 | valid_set_size = max(valid_set_size, 1) | 83 | valid_set_size = max(valid_set_size, 1) |
| 88 | train_set_size = len(self.data) - valid_set_size | 84 | train_set_size = num_images - valid_set_size |
| 89 | 85 | ||
| 90 | self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) | 86 | data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator) |
| 91 | 87 | ||
| 92 | train_dataset = CSVDataset(self.data_train, self.tokenizer, | 88 | self.data_train = self.prepare_subdata(data_train, self.num_class_images) |
| 89 | self.data_val = self.prepare_subdata(data_val) | ||
| 90 | |||
| 91 | def setup(self, stage=None): | ||
| 92 | train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size, | ||
| 93 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, | 93 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
| 94 | num_class_images=self.num_class_images, | 94 | num_class_images=self.num_class_images, |
| 95 | size=self.size, interpolation=self.interpolation, | 95 | size=self.size, interpolation=self.interpolation, |
| 96 | center_crop=self.center_crop, repeats=self.repeats) | 96 | center_crop=self.center_crop, repeats=self.repeats) |
| 97 | val_dataset = CSVDataset(self.data_val, self.tokenizer, | 97 | val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size, |
| 98 | instance_identifier=self.instance_identifier, | 98 | instance_identifier=self.instance_identifier, |
| 99 | size=self.size, interpolation=self.interpolation, | 99 | size=self.size, interpolation=self.interpolation, |
| 100 | center_crop=self.center_crop, repeats=self.repeats) | 100 | center_crop=self.center_crop, repeats=self.repeats) |
| 101 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, | 101 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, |
| 102 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 102 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) |
| 103 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, | 103 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, |
| 104 | pin_memory=True, collate_fn=self.collate_fn) | 104 | pin_memory=True, collate_fn=self.collate_fn) |
| 105 | 105 | ||
| 106 | def train_dataloader(self): | 106 | def train_dataloader(self): |
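Reviewer note on the new `prepare_subdata` helper: it expands every CSV row into `image_multiplier = ceil(num_class_images / len(data))` entries, so each instance image gets its own set of numbered class-image paths. A small self-contained sketch of that expansion (the `rows` list and the `db_cls` directory are made-up stand-ins, not the module's real inputs):

```python
import math
from pathlib import Path

# Hypothetical stand-ins for the CSV rows and directories used above.
rows = [("img/a.png", "a photo of {}"), ("img/b.png", "a painting of {}")]
num_class_images = 5

# ceil(5 / 2) = 3 class images are generated per instance image,
# so at least num_class_images class-image paths exist in total.
image_multiplier = max(math.ceil(num_class_images / len(rows)), 1)

items = [
    (Path(image), Path("db_cls") / f"{Path(image).stem}_{i}{Path(image).suffix}", prompt)
    for image, prompt in rows
    for i in range(image_multiplier)
]

for instance_path, class_path, prompt in items:
    print(instance_path, "->", class_path, "|", prompt)
# prints six pairs: a_0..a_2 for img/a.png, b_0..b_2 for img/b.png
```

With two rows and `num_class_images=5`, ceil(5/2) = 3, so six class-image paths are produced in total, never fewer than requested.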
| @@ -111,24 +111,28 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 111 | 111 | ||
| 112 | 112 | ||
| 113 | class CSVDataset(Dataset): | 113 | class CSVDataset(Dataset): |
| 114 | def __init__(self, | 114 | def __init__( |
| 115 | data, | 115 | self, |
| 116 | tokenizer, | 116 | data: List[CSVDataItem], |
| 117 | instance_identifier, | 117 | tokenizer, |
| 118 | class_identifier=None, | 118 | instance_identifier, |
| 119 | num_class_images=2, | 119 | batch_size=1, |
| 120 | size=512, | 120 | class_identifier=None, |
| 121 | repeats=1, | 121 | num_class_images=0, |
| 122 | interpolation="bicubic", | 122 | size=512, |
| 123 | center_crop=False, | 123 | repeats=1, |
| 124 | ): | 124 | interpolation="bicubic", |
| 125 | center_crop=False, | ||
| 126 | ): | ||
| 125 | 127 | ||
| 126 | self.data = data | 128 | self.data = data |
| 127 | self.tokenizer = tokenizer | 129 | self.tokenizer = tokenizer |
| 130 | self.batch_size = batch_size | ||
| 128 | self.instance_identifier = instance_identifier | 131 | self.instance_identifier = instance_identifier |
| 129 | self.class_identifier = class_identifier | 132 | self.class_identifier = class_identifier |
| 130 | self.num_class_images = num_class_images | 133 | self.num_class_images = num_class_images |
| 131 | self.cache = {} | 134 | self.cache = {} |
| 135 | self.image_cache = {} | ||
| 132 | 136 | ||
| 133 | self.num_instance_images = len(self.data) | 137 | self.num_instance_images = len(self.data) |
| 134 | self._length = self.num_instance_images * repeats | 138 | self._length = self.num_instance_images * repeats |
| @@ -149,46 +153,50 @@ class CSVDataset(Dataset): | |||
| 149 | ) | 153 | ) |
| 150 | 154 | ||
| 151 | def __len__(self): | 155 | def __len__(self): |
| 152 | return self._length | 156 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 153 | 157 | ||
| 154 | def get_example(self, i): | 158 | def get_example(self, i): |
| 155 | instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] | 159 | item = self.data[i % self.num_instance_images] |
| 156 | cache_key = f"{instance_image_path}_{class_image_path}" | 160 | cache_key = f"{item.instance_image_path}_{item.class_image_path}" |
| 157 | 161 | ||
| 158 | if cache_key in self.cache: | 162 | if cache_key in self.cache: |
| 159 | return self.cache[cache_key] | 163 | return self.cache[cache_key] |
| 160 | 164 | ||
| 161 | example = {} | 165 | example = {} |
| 162 | 166 | ||
| 163 | example["prompts"] = prompt | 167 | example["prompts"] = item.prompt |
| 164 | example["nprompts"] = nprompt | 168 | example["nprompts"] = item.nprompt |
| 165 | 169 | ||
| 166 | instance_image = Image.open(instance_image_path) | 170 | if item.instance_image_path in self.image_cache: |
| 167 | if not instance_image.mode == "RGB": | 171 | instance_image = self.image_cache[item.instance_image_path] |
| 168 | instance_image = instance_image.convert("RGB") | 172 | else: |
| 173 | instance_image = Image.open(item.instance_image_path) | ||
| 174 | if not instance_image.mode == "RGB": | ||
| 175 | instance_image = instance_image.convert("RGB") | ||
| 176 | self.image_cache[item.instance_image_path] = instance_image | ||
| 169 | 177 | ||
| 170 | example["instance_images"] = instance_image | 178 | example["instance_images"] = instance_image |
| 171 | example["instance_prompt_ids"] = self.tokenizer( | 179 | example["instance_prompt_ids"] = self.tokenizer( |
| 172 | prompt.format(self.instance_identifier), | 180 | item.prompt.format(self.instance_identifier), |
| 173 | padding="do_not_pad", | 181 | padding="do_not_pad", |
| 174 | truncation=True, | 182 | truncation=True, |
| 175 | max_length=self.tokenizer.model_max_length, | 183 | max_length=self.tokenizer.model_max_length, |
| 176 | ).input_ids | 184 | ).input_ids |
| 177 | 185 | ||
| 178 | if self.num_class_images != 0: | 186 | if self.num_class_images != 0: |
| 179 | class_image = Image.open(class_image_path) | 187 | class_image = Image.open(item.class_image_path) |
| 180 | if not class_image.mode == "RGB": | 188 | if not class_image.mode == "RGB": |
| 181 | class_image = class_image.convert("RGB") | 189 | class_image = class_image.convert("RGB") |
| 182 | 190 | ||
| 183 | example["class_images"] = class_image | 191 | example["class_images"] = class_image |
| 184 | example["class_prompt_ids"] = self.tokenizer( | 192 | example["class_prompt_ids"] = self.tokenizer( |
| 185 | prompt.format(self.class_identifier), | 193 | item.prompt.format(self.class_identifier), |
| 186 | padding="do_not_pad", | 194 | padding="do_not_pad", |
| 187 | truncation=True, | 195 | truncation=True, |
| 188 | max_length=self.tokenizer.model_max_length, | 196 | max_length=self.tokenizer.model_max_length, |
| 189 | ).input_ids | 197 | ).input_ids |
| 190 | 198 | ||
| 191 | self.cache[instance_image_path] = example | 199 | self.cache[item.instance_image_path] = example |
| 192 | return example | 200 | return example |
| 193 | 201 | ||
| 194 | def __getitem__(self, i): | 202 | def __getitem__(self, i): |
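The reworked `__len__` pads the dataset length up to a full multiple of `batch_size`, while `get_example` keeps wrapping indices with `i % self.num_instance_images`, which is presumably why `drop_last=True` was removed from the DataLoaders above. A toy sketch of that behaviour (the `PaddedDataset` class is illustrative, not the real `CSVDataset`):

```python
import math

class PaddedDataset:
    """Toy illustration (not the real CSVDataset) of the batch-size padding above."""

    def __init__(self, items, batch_size=4, repeats=1):
        self.items = items
        self.batch_size = batch_size
        self._length = len(items) * repeats

    def __len__(self):
        # Round up so every batch the DataLoader builds is a full batch.
        return math.ceil(self._length / self.batch_size) * self.batch_size

    def __getitem__(self, i):
        # Indices beyond the real data wrap around to the first items.
        return self.items[i % len(self.items)]

ds = PaddedDataset(list("abcdefghij"), batch_size=4)
print(len(ds))                         # 12, padded up from 10
print([ds[i] for i in range(10, 12)])  # ['a', 'b'] reused as padding
```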
diff --git a/dreambooth.py b/dreambooth.py
index a26bea7..7b61c45 100644
--- a/dreambooth.py
+++ b/dreambooth.py
| @@ -3,6 +3,7 @@ import math | |||
| 3 | import os | 3 | import os |
| 4 | import datetime | 4 | import datetime |
| 5 | import logging | 5 | import logging |
| 6 | import json | ||
| 6 | from pathlib import Path | 7 | from pathlib import Path |
| 7 | 8 | ||
| 8 | import numpy as np | 9 | import numpy as np |
| @@ -21,7 +22,6 @@ from tqdm.auto import tqdm | |||
| 21 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
| 22 | from slugify import slugify | 23 | from slugify import slugify |
| 23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 24 | import json | ||
| 25 | 25 | ||
| 26 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
| 27 | 27 | ||
| @@ -68,7 +68,7 @@ def parse_args(): | |||
| 68 | parser.add_argument( | 68 | parser.add_argument( |
| 69 | "--num_class_images", | 69 | "--num_class_images", |
| 70 | type=int, | 70 | type=int, |
| 71 | default=4, | 71 | default=200, |
| 72 | help="How many class images to generate per training image." | 72 | help="How many class images to generate per training image." |
| 73 | ) | 73 | ) |
| 74 | parser.add_argument( | 74 | parser.add_argument( |
| @@ -140,7 +140,7 @@ def parse_args(): | |||
| 140 | parser.add_argument( | 140 | parser.add_argument( |
| 141 | "--lr_scheduler", | 141 | "--lr_scheduler", |
| 142 | type=str, | 142 | type=str, |
| 143 | default="constant", | 143 | default="linear", |
| 144 | help=( | 144 | help=( |
| 145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| 146 | ' "constant", "constant_with_warmup"]' | 146 | ' "constant", "constant_with_warmup"]' |
| @@ -199,6 +199,12 @@ def parse_args(): | |||
| 199 | help="For distributed training: local_rank" | 199 | help="For distributed training: local_rank" |
| 200 | ) | 200 | ) |
| 201 | parser.add_argument( | 201 | parser.add_argument( |
| 202 | "--sample_frequency", | ||
| 203 | type=int, | ||
| 204 | default=100, | ||
| 205 | help="How often to save a checkpoint and sample image", | ||
| 206 | ) | ||
| 207 | parser.add_argument( | ||
| 202 | "--sample_image_size", | 208 | "--sample_image_size", |
| 203 | type=int, | 209 | type=int, |
| 204 | default=512, | 210 | default=512, |
| @@ -366,20 +372,20 @@ class Checkpointer: | |||
| 366 | generator=generator, | 372 | generator=generator, |
| 367 | ) | 373 | ) |
| 368 | 374 | ||
| 369 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 375 | with torch.inference_mode(): |
| 370 | all_samples = [] | 376 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
| 371 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 377 | all_samples = [] |
| 372 | file_path.parent.mkdir(parents=True, exist_ok=True) | 378 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
| 379 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 373 | 380 | ||
| 374 | data_enum = enumerate(data) | 381 | data_enum = enumerate(data) |
| 375 | 382 | ||
| 376 | for i in range(self.sample_batches): | 383 | for i in range(self.sample_batches): |
| 377 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 384 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
| 378 | prompt = [prompt.format(self.instance_identifier) | 385 | prompt = [prompt.format(self.instance_identifier) |
| 379 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | 386 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
| 380 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | 387 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] |
| 381 | 388 | ||
| 382 | with self.accelerator.autocast(): | ||
| 383 | samples = pipeline( | 389 | samples = pipeline( |
| 384 | prompt=prompt, | 390 | prompt=prompt, |
| 385 | negative_prompt=nprompt, | 391 | negative_prompt=nprompt, |
| @@ -393,15 +399,15 @@ class Checkpointer: | |||
| 393 | output_type='pil' | 399 | output_type='pil' |
| 394 | )["sample"] | 400 | )["sample"] |
| 395 | 401 | ||
| 396 | all_samples += samples | 402 | all_samples += samples |
| 397 | 403 | ||
| 398 | del samples | 404 | del samples |
| 399 | 405 | ||
| 400 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | 406 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
| 401 | image_grid.save(file_path) | 407 | image_grid.save(file_path) |
| 402 | 408 | ||
| 403 | del all_samples | 409 | del all_samples |
| 404 | del image_grid | 410 | del image_grid |
| 405 | 411 | ||
| 406 | del unwrapped | 412 | del unwrapped |
| 407 | del scheduler | 413 | del scheduler |
| @@ -538,7 +544,7 @@ def main(): | |||
| 538 | datamodule.setup() | 544 | datamodule.setup() |
| 539 | 545 | ||
| 540 | if args.num_class_images != 0: | 546 | if args.num_class_images != 0: |
| 541 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 547 | missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] |
| 542 | 548 | ||
| 543 | if len(missing_data) != 0: | 549 | if len(missing_data) != 0: |
| 544 | batched_data = [missing_data[i:i+args.sample_batch_size] | 550 | batched_data = [missing_data[i:i+args.sample_batch_size] |
| @@ -558,20 +564,20 @@ def main(): | |||
| 558 | pipeline.enable_attention_slicing() | 564 | pipeline.enable_attention_slicing() |
| 559 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 565 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 560 | 566 | ||
| 561 | for batch in batched_data: | 567 | with torch.inference_mode(): |
| 562 | image_name = [p[1] for p in batch] | 568 | for batch in batched_data: |
| 563 | prompt = [p[2].format(args.class_identifier) for p in batch] | 569 | image_name = [p.class_image_path for p in batch] |
| 564 | nprompt = [p[3] for p in batch] | 570 | prompt = [p.prompt.format(args.class_identifier) for p in batch] |
| 571 | nprompt = [p.nprompt for p in batch] | ||
| 565 | 572 | ||
| 566 | with accelerator.autocast(): | ||
| 567 | images = pipeline( | 573 | images = pipeline( |
| 568 | prompt=prompt, | 574 | prompt=prompt, |
| 569 | negative_prompt=nprompt, | 575 | negative_prompt=nprompt, |
| 570 | num_inference_steps=args.sample_steps | 576 | num_inference_steps=args.sample_steps |
| 571 | ).images | 577 | ).images |
| 572 | 578 | ||
| 573 | for i, image in enumerate(images): | 579 | for i, image in enumerate(images): |
| 574 | image.save(image_name[i]) | 580 | image.save(image_name[i]) |
| 575 | 581 | ||
| 576 | del pipeline | 582 | del pipeline |
| 577 | 583 | ||
| @@ -677,6 +683,8 @@ def main(): | |||
| 677 | unet.train() | 683 | unet.train() |
| 678 | train_loss = 0.0 | 684 | train_loss = 0.0 |
| 679 | 685 | ||
| 686 | sample_checkpoint = False | ||
| 687 | |||
| 680 | for step, batch in enumerate(train_dataloader): | 688 | for step, batch in enumerate(train_dataloader): |
| 681 | with accelerator.accumulate(unet): | 689 | with accelerator.accumulate(unet): |
| 682 | # Convert images to latent space | 690 | # Convert images to latent space |
| @@ -737,6 +745,9 @@ def main(): | |||
| 737 | 745 | ||
| 738 | global_step += 1 | 746 | global_step += 1 |
| 739 | 747 | ||
| 748 | if global_step % args.sample_frequency == 0: | ||
| 749 | sample_checkpoint = True | ||
| 750 | |||
| 740 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 751 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
| 741 | local_progress_bar.set_postfix(**logs) | 752 | local_progress_bar.set_postfix(**logs) |
| 742 | 753 | ||
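The `sample_checkpoint` flag introduced here is set whenever `global_step` crosses a multiple of the new `--sample_frequency` argument and is consumed after validation by the `if sample_checkpoint and accelerator.is_main_process` guard further down. A minimal sketch of that gating, detached from the real training loop (epoch and step counts are hypothetical):

```python
sample_frequency = 100   # --sample_frequency default in the updated parser
global_step = 0
sampled_at = []

for epoch in range(3):
    sample_checkpoint = False

    for step in range(75):                   # hypothetical steps per epoch
        global_step += 1
        if global_step % sample_frequency == 0:
            sample_checkpoint = True

    if sample_checkpoint:                    # stands in for the is_main_process check
        sampled_at.append(global_step)

print(sampled_at)   # [150, 225]: the epochs that crossed steps 100 and 200
```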
| @@ -783,7 +794,11 @@ def main(): | |||
| 783 | 794 | ||
| 784 | val_loss /= len(val_dataloader) | 795 | val_loss /= len(val_dataloader) |
| 785 | 796 | ||
| 786 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | 797 | accelerator.log({ |
| 798 | "train/loss": train_loss, | ||
| 799 | "val/loss": val_loss, | ||
| 800 | "lr": lr_scheduler.get_last_lr()[0] | ||
| 801 | }, step=global_step) | ||
| 787 | 802 | ||
| 788 | local_progress_bar.clear() | 803 | local_progress_bar.clear() |
| 789 | global_progress_bar.clear() | 804 | global_progress_bar.clear() |
| @@ -792,7 +807,7 @@ def main(): | |||
| 792 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 807 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
| 793 | min_val_loss = val_loss | 808 | min_val_loss = val_loss |
| 794 | 809 | ||
| 795 | if accelerator.is_main_process: | 810 | if sample_checkpoint and accelerator.is_main_process: |
| 796 | checkpointer.save_samples( | 811 | checkpointer.save_samples( |
| 797 | global_step, | 812 | global_step, |
| 798 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 813 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
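Throughout dreambooth.py the sampling code now runs under `torch.inference_mode()` instead of `accelerator.autocast()`. The two contexts are not interchangeable: autocast only changes compute precision, while inference_mode disables autograd tracking entirely, which is the cheaper choice when the pipeline is only generating images. A small illustration:

```python
import torch

x = torch.ones(2, requires_grad=True)

with torch.inference_mode():
    y = x * 2                      # no autograd graph is recorded
print(y.requires_grad)             # False

with torch.autocast("cpu", dtype=torch.bfloat16):
    z = (x * 2).sum()              # precision-only context; gradients still tracked
print(z.requires_grad)             # True
```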
diff --git a/environment.yaml b/environment.yaml
index c9f498e..5ecc5a8 100644
--- a/environment.yaml
+++ b/environment.yaml
| @@ -32,6 +32,6 @@ dependencies: | |||
| 32 | - test-tube>=0.7.5 | 32 | - test-tube>=0.7.5 |
| 33 | - torch-fidelity==0.3.0 | 33 | - torch-fidelity==0.3.0 |
| 34 | - torchmetrics==0.9.3 | 34 | - torchmetrics==0.9.3 |
| 35 | - transformers==4.22.1 | 35 | - transformers==4.22.2 |
| 36 | - triton==2.0.0.dev20220924 | 36 | - triton==2.0.0.dev20220924 |
| 37 | - xformers==0.0.13 | 37 | - xformers==0.0.13 |
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
| @@ -5,12 +5,11 @@ import sys | |||
| 5 | import shlex | 5 | import shlex |
| 6 | import cmd | 6 | import cmd |
| 7 | from pathlib import Path | 7 | from pathlib import Path |
| 8 | from torch import autocast | ||
| 9 | import torch | 8 | import torch |
| 10 | import json | 9 | import json |
| 11 | from PIL import Image | 10 | from PIL import Image |
| 12 | from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler | 11 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler |
| 13 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor | 12 | from transformers import CLIPTextModel, CLIPTokenizer |
| 14 | from slugify import slugify | 13 | from slugify import slugify |
| 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 14 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 16 | from schedulers.scheduling_euler_a import EulerAScheduler | 15 | from schedulers.scheduling_euler_a import EulerAScheduler |
| @@ -22,7 +21,7 @@ torch.backends.cuda.matmul.allow_tf32 = True | |||
| 22 | default_args = { | 21 | default_args = { |
| 23 | "model": None, | 22 | "model": None, |
| 24 | "scheduler": "euler_a", | 23 | "scheduler": "euler_a", |
| 25 | "precision": "bf16", | 24 | "precision": "fp16", |
| 26 | "embeddings_dir": "embeddings", | 25 | "embeddings_dir": "embeddings", |
| 27 | "output_dir": "output/inference", | 26 | "output_dir": "output/inference", |
| 28 | "config": None, | 27 | "config": None, |
| @@ -260,7 +259,7 @@ def generate(output_dir, pipeline, args): | |||
| 260 | else: | 259 | else: |
| 261 | init_image = None | 260 | init_image = None |
| 262 | 261 | ||
| 263 | with autocast("cuda"): | 262 | with torch.autocast("cuda"), torch.inference_mode(): |
| 264 | for i in range(args.batch_num): | 263 | for i in range(args.batch_num): |
| 265 | pipeline.set_progress_bar_config( | 264 | pipeline.set_progress_bar_config( |
| 266 | desc=f"Batch {i + 1} of {args.batch_num}", | 265 | desc=f"Batch {i + 1} of {args.batch_num}", |
| @@ -313,6 +312,9 @@ class CmdParse(cmd.Cmd): | |||
| 313 | args = run_parser(self.parser, default_cmds, elements) | 312 | args = run_parser(self.parser, default_cmds, elements) |
| 314 | except SystemExit: | 313 | except SystemExit: |
| 315 | self.parser.print_help() | 314 | self.parser.print_help() |
| 315 | except Exception as e: | ||
| 316 | print(e) | ||
| 317 | return | ||
| 316 | 318 | ||
| 317 | if len(args.prompt) == 0: | 319 | if len(args.prompt) == 0: |
| 318 | print('Try again with a prompt!') | 320 | print('Try again with a prompt!') |
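infer.py keeps autocast for speed but now stacks it with inference_mode in a single `with` statement, and the default precision moves to fp16. A hedged sketch of that generation context; `pipeline` and `generate_images` are placeholders, not the script's real entry points:

```python
import torch

def generate_images(pipeline, prompts, steps=50):
    """Sketch of infer.py's generation context: CUDA autocast plus no-grad inference."""
    with torch.autocast("cuda"), torch.inference_mode():
        return pipeline(prompt=prompts, num_inference_steps=steps).images

# usage sketch: images = generate_images(pipeline, ["a photo of a cat"])
```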
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a198cf6..bfecd1c 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
| @@ -234,7 +234,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 234 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) | 234 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) |
| 235 | elif isinstance(latents, PIL.Image.Image): | 235 | elif isinstance(latents, PIL.Image.Image): |
| 236 | latents = preprocess(latents, width, height) | 236 | latents = preprocess(latents, width, height) |
| 237 | latent_dist = self.vae.encode(latents.to(self.device)).latent_dist | 237 | latents = latents.to(device=self.device, dtype=latents_dtype) |
| 238 | latent_dist = self.vae.encode(latents).latent_dist | ||
| 238 | latents = latent_dist.sample(generator=generator) | 239 | latents = latent_dist.sample(generator=generator) |
| 239 | latents = 0.18215 * latents | 240 | latents = 0.18215 * latents |
| 240 | 241 | ||
| @@ -249,7 +250,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 249 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) | 250 | timesteps = torch.tensor([timesteps] * batch_size, device=self.device) |
| 250 | 251 | ||
| 251 | # add noise to latents using the timesteps | 252 | # add noise to latents using the timesteps |
| 252 | noise = torch.randn(latents.shape, generator=generator, device=self.device) | 253 | noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) |
| 253 | latents = self.scheduler.add_noise(latents, noise, timesteps) | 254 | latents = self.scheduler.add_noise(latents, noise, timesteps) |
| 254 | else: | 255 | else: |
| 255 | if latents.shape != latents_shape: | 256 | if latents.shape != latents_shape: |
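The img2img branch previously handed the float32 preprocessed image straight to `vae.encode()` and sampled float32 noise; with a half-precision pipeline that mismatches the fp16 weights (typically an "Input type ... and weight type ... should be the same" error). The fix casts both tensors to `latents_dtype` first. A tiny sketch of the aligned dtypes (shapes and values are illustrative):

```python
import torch

latents_dtype = torch.float16          # dtype of text_embeddings in an fp16 pipeline
image = torch.rand(1, 3, 64, 64)       # preprocess() output arrives as float32

# Cast before encoding, and give the added noise the same dtype, so every tensor
# entering the half-precision VAE/UNet agrees on precision.
image = image.to(dtype=latents_dtype)
noise = torch.randn(image.shape, dtype=latents_dtype)

print(image.dtype, noise.dtype)        # torch.float16 torch.float16
```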
diff --git a/textual_inversion.py b/textual_inversion.py
index 4f2de9e..09871d4 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
| @@ -4,6 +4,7 @@ import math | |||
| 4 | import os | 4 | import os |
| 5 | import datetime | 5 | import datetime |
| 6 | import logging | 6 | import logging |
| 7 | import json | ||
| 7 | from pathlib import Path | 8 | from pathlib import Path |
| 8 | 9 | ||
| 9 | import numpy as np | 10 | import numpy as np |
| @@ -22,8 +23,6 @@ from tqdm.auto import tqdm | |||
| 22 | from transformers import CLIPTextModel, CLIPTokenizer | 23 | from transformers import CLIPTextModel, CLIPTokenizer |
| 23 | from slugify import slugify | 24 | from slugify import slugify |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | import json | ||
| 26 | import os | ||
| 27 | 26 | ||
| 28 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
| 29 | 28 | ||
| @@ -70,7 +69,7 @@ def parse_args(): | |||
| 70 | parser.add_argument( | 69 | parser.add_argument( |
| 71 | "--num_class_images", | 70 | "--num_class_images", |
| 72 | type=int, | 71 | type=int, |
| 73 | default=4, | 72 | default=200, |
| 74 | help="How many class images to generate per training image." | 73 | help="How many class images to generate per training image." |
| 75 | ) | 74 | ) |
| 76 | parser.add_argument( | 75 | parser.add_argument( |
| @@ -141,7 +140,7 @@ def parse_args(): | |||
| 141 | parser.add_argument( | 140 | parser.add_argument( |
| 142 | "--lr_scheduler", | 141 | "--lr_scheduler", |
| 143 | type=str, | 142 | type=str, |
| 144 | default="constant", | 143 | default="linear", |
| 145 | help=( | 144 | help=( |
| 146 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| 147 | ' "constant", "constant_with_warmup"]' | 146 | ' "constant", "constant_with_warmup"]' |
| @@ -402,20 +401,20 @@ class Checkpointer: | |||
| 402 | generator=generator, | 401 | generator=generator, |
| 403 | ) | 402 | ) |
| 404 | 403 | ||
| 405 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 404 | with torch.inference_mode(): |
| 406 | all_samples = [] | 405 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
| 407 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 406 | all_samples = [] |
| 408 | file_path.parent.mkdir(parents=True, exist_ok=True) | 407 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
| 408 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 409 | 409 | ||
| 410 | data_enum = enumerate(data) | 410 | data_enum = enumerate(data) |
| 411 | 411 | ||
| 412 | for i in range(self.sample_batches): | 412 | for i in range(self.sample_batches): |
| 413 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 413 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
| 414 | prompt = [prompt.format(self.placeholder_token) | 414 | prompt = [prompt.format(self.placeholder_token) |
| 415 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | 415 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
| 416 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | 416 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] |
| 417 | 417 | ||
| 418 | with self.accelerator.autocast(): | ||
| 419 | samples = pipeline( | 418 | samples = pipeline( |
| 420 | prompt=prompt, | 419 | prompt=prompt, |
| 421 | negative_prompt=nprompt, | 420 | negative_prompt=nprompt, |
| @@ -429,15 +428,15 @@ class Checkpointer: | |||
| 429 | output_type='pil' | 428 | output_type='pil' |
| 430 | )["sample"] | 429 | )["sample"] |
| 431 | 430 | ||
| 432 | all_samples += samples | 431 | all_samples += samples |
| 433 | 432 | ||
| 434 | del samples | 433 | del samples |
| 435 | 434 | ||
| 436 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | 435 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
| 437 | image_grid.save(file_path) | 436 | image_grid.save(file_path) |
| 438 | 437 | ||
| 439 | del all_samples | 438 | del all_samples |
| 440 | del image_grid | 439 | del image_grid |
| 441 | 440 | ||
| 442 | del unwrapped | 441 | del unwrapped |
| 443 | del scheduler | 442 | del scheduler |
| @@ -623,7 +622,7 @@ def main(): | |||
| 623 | datamodule.setup() | 622 | datamodule.setup() |
| 624 | 623 | ||
| 625 | if args.num_class_images != 0: | 624 | if args.num_class_images != 0: |
| 626 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 625 | missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] |
| 627 | 626 | ||
| 628 | if len(missing_data) != 0: | 627 | if len(missing_data) != 0: |
| 629 | batched_data = [missing_data[i:i+args.sample_batch_size] | 628 | batched_data = [missing_data[i:i+args.sample_batch_size] |
| @@ -643,20 +642,20 @@ def main(): | |||
| 643 | pipeline.enable_attention_slicing() | 642 | pipeline.enable_attention_slicing() |
| 644 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 643 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 645 | 644 | ||
| 646 | for batch in batched_data: | 645 | with torch.inference_mode(): |
| 647 | image_name = [p[1] for p in batch] | 646 | for batch in batched_data: |
| 648 | prompt = [p[2].format(args.initializer_token) for p in batch] | 647 | image_name = [p.class_image_path for p in batch] |
| 649 | nprompt = [p[3] for p in batch] | 648 | prompt = [p.prompt.format(args.initializer_token) for p in batch] |
| 649 | nprompt = [p.nprompt for p in batch] | ||
| 650 | 650 | ||
| 651 | with accelerator.autocast(): | ||
| 652 | images = pipeline( | 651 | images = pipeline( |
| 653 | prompt=prompt, | 652 | prompt=prompt, |
| 654 | negative_prompt=nprompt, | 653 | negative_prompt=nprompt, |
| 655 | num_inference_steps=args.sample_steps | 654 | num_inference_steps=args.sample_steps |
| 656 | ).images | 655 | ).images |
| 657 | 656 | ||
| 658 | for i, image in enumerate(images): | 657 | for i, image in enumerate(images): |
| 659 | image.save(image_name[i]) | 658 | image.save(image_name[i]) |
| 660 | 659 | ||
| 661 | del pipeline | 660 | del pipeline |
| 662 | 661 | ||
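Both training scripts now filter `datamodule.data_train` by the named `class_image_path` field instead of indexing bare tuples. A self-contained sketch of that filtering and batching (the paths are hypothetical; `CSVDataItem` is copied from data/csv.py so the snippet runs on its own):

```python
from pathlib import Path
from typing import NamedTuple

# Mirrors data/csv.py's CSVDataItem so the filtering below is self-contained.
class CSVDataItem(NamedTuple):
    instance_image_path: Path
    class_image_path: Path
    prompt: str
    nprompt: str

data_train = [
    CSVDataItem(Path("img/a.png"), Path("db_cls/a_0.png"), "a photo of {}", ""),
    CSVDataItem(Path("img/b.png"), Path("db_cls/b_0.png"), "a photo of {}", "blurry"),
]

# Same check both training scripts now use: only generate class images that are missing.
missing = [item for item in data_train if not item.class_image_path.exists()]

sample_batch_size = 1
batched = [missing[i:i + sample_batch_size] for i in range(0, len(missing), sample_batch_size)]
print(len(missing), "missing class images in", len(batched), "batches")
```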
