 data/csv.py                                          | 162
 dreambooth.py                                        |  75
 environment.yaml                                     |   2
 infer.py                                             |  12
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  |   5
 textual_inversion.py                                 |  57
 6 files changed, 169 insertions(+), 144 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index dcaf7d3..8637ac1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,27 +1,38 @@
+import math
 import pandas as pd
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
+from typing import NamedTuple, List
+
+
+class CSVDataItem(NamedTuple):
+    instance_image_path: Path
+    class_image_path: Path
+    prompt: str
+    nprompt: str
 
 
 class CSVDataModule(pl.LightningDataModule):
-    def __init__(self,
-                 batch_size,
-                 data_file,
-                 tokenizer,
-                 instance_identifier,
-                 class_identifier=None,
-                 class_subdir="db_cls",
-                 num_class_images=2,
-                 size=512,
-                 repeats=100,
-                 interpolation="bicubic",
-                 center_crop=False,
-                 valid_set_size=None,
-                 generator=None,
-                 collate_fn=None):
+    def __init__(
+        self,
+        batch_size,
+        data_file,
+        tokenizer,
+        instance_identifier,
+        class_identifier=None,
+        class_subdir="db_cls",
+        num_class_images=100,
+        size=512,
+        repeats=100,
+        interpolation="bicubic",
+        center_crop=False,
+        valid_set_size=None,
+        generator=None,
+        collate_fn=None
+    ):
         super().__init__()
 
         self.data_file = Path(data_file)
@@ -46,61 +57,50 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
+    def prepare_subdata(self, data, num_class_images=1):
+        image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
+
+        return [
+            CSVDataItem(
+                self.data_root.joinpath(item.image),
+                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
+                item.prompt,
+                item.nprompt if "nprompt" in item else ""
+            )
+            for item in data
+            if "skip" not in item or item.skip != "x"
+            for i in range(image_multiplier)
+        ]
+
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        instance_image_paths = [
-            self.data_root.joinpath(f)
-            for f in metadata['image'].values
-            for i in range(self.num_class_images)
-        ]
-        class_image_paths = [
-            self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
-            for f in metadata['image'].values
-            for i in range(self.num_class_images)
-        ]
-        prompts = [
-            prompt
-            for prompt in metadata['prompt'].values
-            for i in range(self.num_class_images)
-        ]
-        nprompts = [
-            nprompt
-            for nprompt in metadata['nprompt'].values
-            for i in range(self.num_class_images)
-        ] if 'nprompt' in metadata else [""] * len(instance_image_paths)
-        skips = [
-            skip
-            for skip in metadata['skip'].values
-            for i in range(self.num_class_images)
-        ] if 'skip' in metadata else [""] * len(instance_image_paths)
-        self.data = [
-            (i, c, p, n)
-            for i, c, p, n, s
-            in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
-            if s != "x"
-        ]
+        metadata = list(metadata.itertuples())
+        num_images = len(metadata)
 
-    def setup(self, stage=None):
-        valid_set_size = int(len(self.data) * 0.2)
+        valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
             valid_set_size = min(valid_set_size, self.valid_set_size)
         valid_set_size = max(valid_set_size, 1)
-        train_set_size = len(self.data) - valid_set_size
+        train_set_size = num_images - valid_set_size
 
-        self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
 
-        train_dataset = CSVDataset(self.data_train, self.tokenizer,
+        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(data_val)
+
+    def setup(self, stage=None):
+        train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size,
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
                                    center_crop=self.center_crop, repeats=self.repeats)
-        val_dataset = CSVDataset(self.data_val, self.tokenizer,
+        val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size,
                                  instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
                                  center_crop=self.center_crop, repeats=self.repeats)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
                                             shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
                                           pin_memory=True, collate_fn=self.collate_fn)
 
     def train_dataloader(self):
@@ -111,24 +111,28 @@ class CSVDataModule(pl.LightningDataModule):
 
 
 class CSVDataset(Dataset):
-    def __init__(self,
-                 data,
-                 tokenizer,
-                 instance_identifier,
-                 class_identifier=None,
-                 num_class_images=2,
-                 size=512,
-                 repeats=1,
-                 interpolation="bicubic",
-                 center_crop=False,
-                 ):
+    def __init__(
+        self,
+        data: List[CSVDataItem],
+        tokenizer,
+        instance_identifier,
+        batch_size=1,
+        class_identifier=None,
+        num_class_images=0,
+        size=512,
+        repeats=1,
+        interpolation="bicubic",
+        center_crop=False,
+    ):
 
         self.data = data
         self.tokenizer = tokenizer
+        self.batch_size = batch_size
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
         self.cache = {}
+        self.image_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -149,46 +153,50 @@ class CSVDataset(Dataset):
         )
 
     def __len__(self):
-        return self._length
+        return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
-        cache_key = f"{instance_image_path}_{class_image_path}"
+        item = self.data[i % self.num_instance_images]
+        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
 
         if cache_key in self.cache:
             return self.cache[cache_key]
 
         example = {}
 
-        example["prompts"] = prompt
-        example["nprompts"] = nprompt
+        example["prompts"] = item.prompt
+        example["nprompts"] = item.nprompt
 
-        instance_image = Image.open(instance_image_path)
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
+        if item.instance_image_path in self.image_cache:
+            instance_image = self.image_cache[item.instance_image_path]
+        else:
+            instance_image = Image.open(item.instance_image_path)
+            if not instance_image.mode == "RGB":
+                instance_image = instance_image.convert("RGB")
+            self.image_cache[item.instance_image_path] = instance_image
 
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
-            prompt.format(self.instance_identifier),
+            item.prompt.format(self.instance_identifier),
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids
 
         if self.num_class_images != 0:
-            class_image = Image.open(class_image_path)
+            class_image = Image.open(item.class_image_path)
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
 
             example["class_images"] = class_image
             example["class_prompt_ids"] = self.tokenizer(
-                prompt.format(self.class_identifier),
+                item.prompt.format(self.class_identifier),
                 padding="do_not_pad",
                 truncation=True,
                 max_length=self.tokenizer.model_max_length,
             ).input_ids
 
-        self.cache[instance_image_path] = example
+        self.cache[item.instance_image_path] = example
         return example
 
     def __getitem__(self, i):
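Note on the dataset change: the new `CSVDataset.__len__` rounds the item count up to a whole number of batches, and `get_example` wraps its index with `i % self.num_instance_images`, so the last batch of an epoch is padded with repeated items instead of being dropped via `drop_last=True` on the DataLoaders. A minimal sketch of that arithmetic with toy values (not code from this repository):

```python
import math

# Toy stand-ins for the CSVDataset internals: 5 items, batch size 4.
items = ["a.png", "b.png", "c.png", "d.png", "e.png"]
batch_size = 4

raw_length = len(items)                                           # 5
padded_length = math.ceil(raw_length / batch_size) * batch_size   # 8 -> two full batches

# Indices past the real length wrap around, reusing earlier items.
epoch = [items[i % raw_length] for i in range(padded_length)]
print(epoch)  # ['a.png', 'b.png', 'c.png', 'd.png', 'e.png', 'a.png', 'b.png', 'c.png']
```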
diff --git a/dreambooth.py b/dreambooth.py
index a26bea7..7b61c45 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -3,6 +3,7 @@ import math
 import os
 import datetime
 import logging
+import json
 from pathlib import Path
 
 import numpy as np
@@ -21,7 +22,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-import json
 
 from data.csv import CSVDataModule
 
@@ -68,7 +68,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=4,
+        default=200,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -140,7 +140,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="constant",
+        default="linear",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -199,6 +199,12 @@ def parse_args():
         help="For distributed training: local_rank"
     )
     parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=100,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=512,
@@ -366,20 +372,20 @@ class Checkpointer:
             generator=generator,
         )
 
-        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
+        with torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
+                data_enum = enumerate(data)
 
-            for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                for i in range(self.sample_batches):
+                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                    prompt = [prompt.format(self.instance_identifier)
+                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
-                with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
@@ -393,15 +399,15 @@ class Checkpointer:
                         output_type='pil'
                     )["sample"]
 
-                all_samples += samples
+                    all_samples += samples
 
-                del samples
+                    del samples
 
-            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path)
 
-            del all_samples
-            del image_grid
+                del all_samples
+                del image_grid
 
         del unwrapped
         del scheduler
@@ -538,7 +544,7 @@ def main():
     datamodule.setup()
 
     if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data if not item[1].exists()]
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
             batched_data = [missing_data[i:i+args.sample_batch_size]
@@ -558,20 +564,20 @@ def main():
             pipeline.enable_attention_slicing()
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-            for batch in batched_data:
-                image_name = [p[1] for p in batch]
-                prompt = [p[2].format(args.class_identifier) for p in batch]
-                nprompt = [p[3] for p in batch]
+            with torch.inference_mode():
+                for batch in batched_data:
+                    image_name = [p.class_image_path for p in batch]
+                    prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                    nprompt = [p.nprompt for p in batch]
 
-                with accelerator.autocast():
                     images = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
                         num_inference_steps=args.sample_steps
                     ).images
 
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
+                    for i, image in enumerate(images):
+                        image.save(image_name[i])
 
             del pipeline
 
@@ -677,6 +683,8 @@ def main():
         unet.train()
         train_loss = 0.0
 
+        sample_checkpoint = False
+
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
@@ -737,6 +745,9 @@ def main():
 
                 global_step += 1
 
+                if global_step % args.sample_frequency == 0:
+                    sample_checkpoint = True
+
             logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
             local_progress_bar.set_postfix(**logs)
 
@@ -783,7 +794,11 @@ def main():
 
         val_loss /= len(val_dataloader)
 
-        accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
+        accelerator.log({
+            "train/loss": train_loss,
+            "val/loss": val_loss,
+            "lr": lr_scheduler.get_last_lr()[0]
+        }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
@@ -792,7 +807,7 @@ def main():
             accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
             min_val_loss = val_loss
 
-        if accelerator.is_main_process:
+        if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
                 global_step,
                 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
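Note on the sampling change: both training scripts now wrap sample and class-image generation in a single `torch.inference_mode()` block instead of per-call `accelerator.autocast()` contexts, so no autograd state is recorded while the pipeline runs. A minimal sketch of that pattern with a generic diffusers pipeline (the model id and prompts below are placeholders, not taken from this repository):

```python
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical standalone example of generating images under inference_mode().
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model id
    torch_dtype=torch.float16,
).to("cuda")

prompts = ["a photo of a corgi", "a watercolor landscape"]

# Everything inside the block runs without autograd tracking.
with torch.inference_mode():
    images = pipe(prompt=prompts, num_inference_steps=30).images

for i, image in enumerate(images):
    image.save(f"sample_{i}.png")
```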
diff --git a/environment.yaml b/environment.yaml
index c9f498e..5ecc5a8 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -32,6 +32,6 @@ dependencies:
   - test-tube>=0.7.5
   - torch-fidelity==0.3.0
   - torchmetrics==0.9.3
-  - transformers==4.22.1
+  - transformers==4.22.2
   - triton==2.0.0.dev20220924
   - xformers==0.0.13
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -5,12 +5,11 @@ import sys
 import shlex
 import cmd
 from pathlib import Path
-from torch import autocast
 import torch
 import json
 from PIL import Image
-from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
+from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
+from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
@@ -22,7 +21,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 default_args = {
     "model": None,
     "scheduler": "euler_a",
-    "precision": "bf16",
+    "precision": "fp16",
     "embeddings_dir": "embeddings",
     "output_dir": "output/inference",
     "config": None,
@@ -260,7 +259,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
-    with autocast("cuda"):
+    with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
                 desc=f"Batch {i + 1} of {args.batch_num}",
@@ -313,6 +312,9 @@ class CmdParse(cmd.Cmd):
             args = run_parser(self.parser, default_cmds, elements)
         except SystemExit:
             self.parser.print_help()
+        except Exception as e:
+            print(e)
+            return
 
         if len(args.prompt) == 0:
             print('Try again with a prompt!')
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a198cf6..bfecd1c 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -234,7 +234,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
         elif isinstance(latents, PIL.Image.Image):
             latents = preprocess(latents, width, height)
-            latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
+            latents = latents.to(device=self.device, dtype=latents_dtype)
+            latent_dist = self.vae.encode(latents).latent_dist
             latents = latent_dist.sample(generator=generator)
             latents = 0.18215 * latents
 
@@ -249,7 +250,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
 
             # add noise to latents using the timesteps
-            noise = torch.randn(latents.shape, generator=generator, device=self.device)
+            noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
             latents = self.scheduler.add_noise(latents, noise, timesteps)
         else:
             if latents.shape != latents_shape:
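Note on the pipeline fix: the preprocessed init image and the added noise are now cast to `latents_dtype`, so a half-precision pipeline no longer feeds float32 tensors into `vae.encode` and `scheduler.add_noise`. A small illustrative sketch of the cast (standalone tensors, not the pipeline's real variables):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
latents_dtype = torch.float16  # dtype of the text embeddings / model weights

# Preprocessed images come out as float32 by default; cast once up front
# so every later op sees a consistent dtype and device.
init_image = torch.rand(1, 3, 512, 512).to(device=device, dtype=latents_dtype)
noise = torch.randn(init_image.shape, device=device, dtype=latents_dtype)

assert init_image.dtype == noise.dtype == latents_dtype
```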
diff --git a/textual_inversion.py b/textual_inversion.py
index 4f2de9e..09871d4 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -4,6 +4,7 @@ import math
 import os
 import datetime
 import logging
+import json
 from pathlib import Path
 
 import numpy as np
@@ -22,8 +23,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-import json
-import os
 
 from data.csv import CSVDataModule
 
@@ -70,7 +69,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=4,
+        default=200,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -141,7 +140,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="constant",
+        default="linear",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -402,20 +401,20 @@ class Checkpointer:
             generator=generator,
         )
 
-        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
+        with torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
+                data_enum = enumerate(data)
 
-            for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.placeholder_token)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                for i in range(self.sample_batches):
+                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                    prompt = [prompt.format(self.placeholder_token)
+                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
-                with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
@@ -429,15 +428,15 @@ class Checkpointer:
                         output_type='pil'
                     )["sample"]
 
-                all_samples += samples
+                    all_samples += samples
 
-                del samples
+                    del samples
 
-            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path)
 
-            del all_samples
-            del image_grid
+                del all_samples
+                del image_grid
 
         del unwrapped
         del scheduler
@@ -623,7 +622,7 @@ def main():
     datamodule.setup()
 
     if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data if not item[1].exists()]
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
             batched_data = [missing_data[i:i+args.sample_batch_size]
@@ -643,20 +642,20 @@ def main():
             pipeline.enable_attention_slicing()
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-            for batch in batched_data:
-                image_name = [p[1] for p in batch]
-                prompt = [p[2].format(args.initializer_token) for p in batch]
-                nprompt = [p[3] for p in batch]
+            with torch.inference_mode():
+                for batch in batched_data:
+                    image_name = [p.class_image_path for p in batch]
+                    prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                    nprompt = [p.nprompt for p in batch]
 
-                with accelerator.autocast():
                     images = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
                         num_inference_steps=args.sample_steps
                     ).images
 
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
+                    for i, image in enumerate(images):
+                        image.save(image_name[i])
 
             del pipeline
 