summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py162
-rw-r--r--dreambooth.py75
-rw-r--r--environment.yaml2
-rw-r--r--infer.py12
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py5
-rw-r--r--textual_inversion.py57
6 files changed, 169 insertions, 144 deletions
diff --git a/data/csv.py b/data/csv.py
index dcaf7d3..8637ac1 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,27 +1,38 @@
1import math
1import pandas as pd 2import pandas as pd
2from pathlib import Path 3from pathlib import Path
3import pytorch_lightning as pl 4import pytorch_lightning as pl
4from PIL import Image 5from PIL import Image
5from torch.utils.data import Dataset, DataLoader, random_split 6from torch.utils.data import Dataset, DataLoader, random_split
6from torchvision import transforms 7from torchvision import transforms
8from typing import NamedTuple, List
9
10
11class CSVDataItem(NamedTuple):
12 instance_image_path: Path
13 class_image_path: Path
14 prompt: str
15 nprompt: str
7 16
8 17
9class CSVDataModule(pl.LightningDataModule): 18class CSVDataModule(pl.LightningDataModule):
10 def __init__(self, 19 def __init__(
11 batch_size, 20 self,
12 data_file, 21 batch_size,
13 tokenizer, 22 data_file,
14 instance_identifier, 23 tokenizer,
15 class_identifier=None, 24 instance_identifier,
16 class_subdir="db_cls", 25 class_identifier=None,
17 num_class_images=2, 26 class_subdir="db_cls",
18 size=512, 27 num_class_images=100,
19 repeats=100, 28 size=512,
20 interpolation="bicubic", 29 repeats=100,
21 center_crop=False, 30 interpolation="bicubic",
22 valid_set_size=None, 31 center_crop=False,
23 generator=None, 32 valid_set_size=None,
24 collate_fn=None): 33 generator=None,
34 collate_fn=None
35 ):
25 super().__init__() 36 super().__init__()
26 37
27 self.data_file = Path(data_file) 38 self.data_file = Path(data_file)
@@ -46,61 +57,50 @@ class CSVDataModule(pl.LightningDataModule):
46 self.collate_fn = collate_fn 57 self.collate_fn = collate_fn
47 self.batch_size = batch_size 58 self.batch_size = batch_size
48 59
60 def prepare_subdata(self, data, num_class_images=1):
61 image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
62
63 return [
64 CSVDataItem(
65 self.data_root.joinpath(item.image),
66 self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
67 item.prompt,
68 item.nprompt if "nprompt" in item else ""
69 )
70 for item in data
71 if "skip" not in item or item.skip != "x"
72 for i in range(image_multiplier)
73 ]
74
49 def prepare_data(self): 75 def prepare_data(self):
50 metadata = pd.read_csv(self.data_file) 76 metadata = pd.read_csv(self.data_file)
51 instance_image_paths = [ 77 metadata = list(metadata.itertuples())
52 self.data_root.joinpath(f) 78 num_images = len(metadata)
53 for f in metadata['image'].values
54 for i in range(self.num_class_images)
55 ]
56 class_image_paths = [
57 self.class_root.joinpath(f"{Path(f).stem}_{i}_{Path(f).suffix}")
58 for f in metadata['image'].values
59 for i in range(self.num_class_images)
60 ]
61 prompts = [
62 prompt
63 for prompt in metadata['prompt'].values
64 for i in range(self.num_class_images)
65 ]
66 nprompts = [
67 nprompt
68 for nprompt in metadata['nprompt'].values
69 for i in range(self.num_class_images)
70 ] if 'nprompt' in metadata else [""] * len(instance_image_paths)
71 skips = [
72 skip
73 for skip in metadata['skip'].values
74 for i in range(self.num_class_images)
75 ] if 'skip' in metadata else [""] * len(instance_image_paths)
76 self.data = [
77 (i, c, p, n)
78 for i, c, p, n, s
79 in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
80 if s != "x"
81 ]
82 79
83 def setup(self, stage=None): 80 valid_set_size = int(num_images * 0.2)
84 valid_set_size = int(len(self.data) * 0.2)
85 if self.valid_set_size: 81 if self.valid_set_size:
86 valid_set_size = min(valid_set_size, self.valid_set_size) 82 valid_set_size = min(valid_set_size, self.valid_set_size)
87 valid_set_size = max(valid_set_size, 1) 83 valid_set_size = max(valid_set_size, 1)
88 train_set_size = len(self.data) - valid_set_size 84 train_set_size = num_images - valid_set_size
89 85
90 self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) 86 data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
91 87
92 train_dataset = CSVDataset(self.data_train, self.tokenizer, 88 self.data_train = self.prepare_subdata(data_train, self.num_class_images)
89 self.data_val = self.prepare_subdata(data_val)
90
91 def setup(self, stage=None):
92 train_dataset = CSVDataset(self.data_train, self.tokenizer, batch_size=self.batch_size,
93 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, 93 instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
94 num_class_images=self.num_class_images, 94 num_class_images=self.num_class_images,
95 size=self.size, interpolation=self.interpolation, 95 size=self.size, interpolation=self.interpolation,
96 center_crop=self.center_crop, repeats=self.repeats) 96 center_crop=self.center_crop, repeats=self.repeats)
97 val_dataset = CSVDataset(self.data_val, self.tokenizer, 97 val_dataset = CSVDataset(self.data_val, self.tokenizer, batch_size=self.batch_size,
98 instance_identifier=self.instance_identifier, 98 instance_identifier=self.instance_identifier,
99 size=self.size, interpolation=self.interpolation, 99 size=self.size, interpolation=self.interpolation,
100 center_crop=self.center_crop, repeats=self.repeats) 100 center_crop=self.center_crop, repeats=self.repeats)
101 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, 101 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
102 shuffle=True, pin_memory=True, collate_fn=self.collate_fn) 102 shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
103 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, 103 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
104 pin_memory=True, collate_fn=self.collate_fn) 104 pin_memory=True, collate_fn=self.collate_fn)
105 105
106 def train_dataloader(self): 106 def train_dataloader(self):
@@ -111,24 +111,28 @@ class CSVDataModule(pl.LightningDataModule):
111 111
112 112
113class CSVDataset(Dataset): 113class CSVDataset(Dataset):
114 def __init__(self, 114 def __init__(
115 data, 115 self,
116 tokenizer, 116 data: List[CSVDataItem],
117 instance_identifier, 117 tokenizer,
118 class_identifier=None, 118 instance_identifier,
119 num_class_images=2, 119 batch_size=1,
120 size=512, 120 class_identifier=None,
121 repeats=1, 121 num_class_images=0,
122 interpolation="bicubic", 122 size=512,
123 center_crop=False, 123 repeats=1,
124 ): 124 interpolation="bicubic",
125 center_crop=False,
126 ):
125 127
126 self.data = data 128 self.data = data
127 self.tokenizer = tokenizer 129 self.tokenizer = tokenizer
130 self.batch_size = batch_size
128 self.instance_identifier = instance_identifier 131 self.instance_identifier = instance_identifier
129 self.class_identifier = class_identifier 132 self.class_identifier = class_identifier
130 self.num_class_images = num_class_images 133 self.num_class_images = num_class_images
131 self.cache = {} 134 self.cache = {}
135 self.image_cache = {}
132 136
133 self.num_instance_images = len(self.data) 137 self.num_instance_images = len(self.data)
134 self._length = self.num_instance_images * repeats 138 self._length = self.num_instance_images * repeats
@@ -149,46 +153,50 @@ class CSVDataset(Dataset):
149 ) 153 )
150 154
151 def __len__(self): 155 def __len__(self):
152 return self._length 156 return math.ceil(self._length / self.batch_size) * self.batch_size
153 157
154 def get_example(self, i): 158 def get_example(self, i):
155 instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] 159 item = self.data[i % self.num_instance_images]
156 cache_key = f"{instance_image_path}_{class_image_path}" 160 cache_key = f"{item.instance_image_path}_{item.class_image_path}"
157 161
158 if cache_key in self.cache: 162 if cache_key in self.cache:
159 return self.cache[cache_key] 163 return self.cache[cache_key]
160 164
161 example = {} 165 example = {}
162 166
163 example["prompts"] = prompt 167 example["prompts"] = item.prompt
164 example["nprompts"] = nprompt 168 example["nprompts"] = item.nprompt
165 169
166 instance_image = Image.open(instance_image_path) 170 if item.instance_image_path in self.image_cache:
167 if not instance_image.mode == "RGB": 171 instance_image = self.image_cache[item.instance_image_path]
168 instance_image = instance_image.convert("RGB") 172 else:
173 instance_image = Image.open(item.instance_image_path)
174 if not instance_image.mode == "RGB":
175 instance_image = instance_image.convert("RGB")
176 self.image_cache[item.instance_image_path] = instance_image
169 177
170 example["instance_images"] = instance_image 178 example["instance_images"] = instance_image
171 example["instance_prompt_ids"] = self.tokenizer( 179 example["instance_prompt_ids"] = self.tokenizer(
172 prompt.format(self.instance_identifier), 180 item.prompt.format(self.instance_identifier),
173 padding="do_not_pad", 181 padding="do_not_pad",
174 truncation=True, 182 truncation=True,
175 max_length=self.tokenizer.model_max_length, 183 max_length=self.tokenizer.model_max_length,
176 ).input_ids 184 ).input_ids
177 185
178 if self.num_class_images != 0: 186 if self.num_class_images != 0:
179 class_image = Image.open(class_image_path) 187 class_image = Image.open(item.class_image_path)
180 if not class_image.mode == "RGB": 188 if not class_image.mode == "RGB":
181 class_image = class_image.convert("RGB") 189 class_image = class_image.convert("RGB")
182 190
183 example["class_images"] = class_image 191 example["class_images"] = class_image
184 example["class_prompt_ids"] = self.tokenizer( 192 example["class_prompt_ids"] = self.tokenizer(
185 prompt.format(self.class_identifier), 193 item.prompt.format(self.class_identifier),
186 padding="do_not_pad", 194 padding="do_not_pad",
187 truncation=True, 195 truncation=True,
188 max_length=self.tokenizer.model_max_length, 196 max_length=self.tokenizer.model_max_length,
189 ).input_ids 197 ).input_ids
190 198
191 self.cache[instance_image_path] = example 199 self.cache[item.instance_image_path] = example
192 return example 200 return example
193 201
194 def __getitem__(self, i): 202 def __getitem__(self, i):
diff --git a/dreambooth.py b/dreambooth.py
index a26bea7..7b61c45 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -3,6 +3,7 @@ import math
3import os 3import os
4import datetime 4import datetime
5import logging 5import logging
6import json
6from pathlib import Path 7from pathlib import Path
7 8
8import numpy as np 9import numpy as np
@@ -21,7 +22,6 @@ from tqdm.auto import tqdm
21from transformers import CLIPTextModel, CLIPTokenizer 22from transformers import CLIPTextModel, CLIPTokenizer
22from slugify import slugify 23from slugify import slugify
23from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
24import json
25 25
26from data.csv import CSVDataModule 26from data.csv import CSVDataModule
27 27
@@ -68,7 +68,7 @@ def parse_args():
68 parser.add_argument( 68 parser.add_argument(
69 "--num_class_images", 69 "--num_class_images",
70 type=int, 70 type=int,
71 default=4, 71 default=200,
72 help="How many class images to generate per training image." 72 help="How many class images to generate per training image."
73 ) 73 )
74 parser.add_argument( 74 parser.add_argument(
@@ -140,7 +140,7 @@ def parse_args():
140 parser.add_argument( 140 parser.add_argument(
141 "--lr_scheduler", 141 "--lr_scheduler",
142 type=str, 142 type=str,
143 default="constant", 143 default="linear",
144 help=( 144 help=(
145 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 145 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
146 ' "constant", "constant_with_warmup"]' 146 ' "constant", "constant_with_warmup"]'
@@ -199,6 +199,12 @@ def parse_args():
199 help="For distributed training: local_rank" 199 help="For distributed training: local_rank"
200 ) 200 )
201 parser.add_argument( 201 parser.add_argument(
202 "--sample_frequency",
203 type=int,
204 default=100,
205 help="How often to save a checkpoint and sample image",
206 )
207 parser.add_argument(
202 "--sample_image_size", 208 "--sample_image_size",
203 type=int, 209 type=int,
204 default=512, 210 default=512,
@@ -366,20 +372,20 @@ class Checkpointer:
366 generator=generator, 372 generator=generator,
367 ) 373 )
368 374
369 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 375 with torch.inference_mode():
370 all_samples = [] 376 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
371 file_path = samples_path.joinpath(pool, f"step_{step}.png") 377 all_samples = []
372 file_path.parent.mkdir(parents=True, exist_ok=True) 378 file_path = samples_path.joinpath(pool, f"step_{step}.png")
379 file_path.parent.mkdir(parents=True, exist_ok=True)
373 380
374 data_enum = enumerate(data) 381 data_enum = enumerate(data)
375 382
376 for i in range(self.sample_batches): 383 for i in range(self.sample_batches):
377 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 384 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
378 prompt = [prompt.format(self.instance_identifier) 385 prompt = [prompt.format(self.instance_identifier)
379 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] 386 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
380 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] 387 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
381 388
382 with self.accelerator.autocast():
383 samples = pipeline( 389 samples = pipeline(
384 prompt=prompt, 390 prompt=prompt,
385 negative_prompt=nprompt, 391 negative_prompt=nprompt,
@@ -393,15 +399,15 @@ class Checkpointer:
393 output_type='pil' 399 output_type='pil'
394 )["sample"] 400 )["sample"]
395 401
396 all_samples += samples 402 all_samples += samples
397 403
398 del samples 404 del samples
399 405
400 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) 406 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
401 image_grid.save(file_path) 407 image_grid.save(file_path)
402 408
403 del all_samples 409 del all_samples
404 del image_grid 410 del image_grid
405 411
406 del unwrapped 412 del unwrapped
407 del scheduler 413 del scheduler
@@ -538,7 +544,7 @@ def main():
538 datamodule.setup() 544 datamodule.setup()
539 545
540 if args.num_class_images != 0: 546 if args.num_class_images != 0:
541 missing_data = [item for item in datamodule.data if not item[1].exists()] 547 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
542 548
543 if len(missing_data) != 0: 549 if len(missing_data) != 0:
544 batched_data = [missing_data[i:i+args.sample_batch_size] 550 batched_data = [missing_data[i:i+args.sample_batch_size]
@@ -558,20 +564,20 @@ def main():
558 pipeline.enable_attention_slicing() 564 pipeline.enable_attention_slicing()
559 pipeline.set_progress_bar_config(dynamic_ncols=True) 565 pipeline.set_progress_bar_config(dynamic_ncols=True)
560 566
561 for batch in batched_data: 567 with torch.inference_mode():
562 image_name = [p[1] for p in batch] 568 for batch in batched_data:
563 prompt = [p[2].format(args.class_identifier) for p in batch] 569 image_name = [p.class_image_path for p in batch]
564 nprompt = [p[3] for p in batch] 570 prompt = [p.prompt.format(args.class_identifier) for p in batch]
571 nprompt = [p.nprompt for p in batch]
565 572
566 with accelerator.autocast():
567 images = pipeline( 573 images = pipeline(
568 prompt=prompt, 574 prompt=prompt,
569 negative_prompt=nprompt, 575 negative_prompt=nprompt,
570 num_inference_steps=args.sample_steps 576 num_inference_steps=args.sample_steps
571 ).images 577 ).images
572 578
573 for i, image in enumerate(images): 579 for i, image in enumerate(images):
574 image.save(image_name[i]) 580 image.save(image_name[i])
575 581
576 del pipeline 582 del pipeline
577 583
@@ -677,6 +683,8 @@ def main():
677 unet.train() 683 unet.train()
678 train_loss = 0.0 684 train_loss = 0.0
679 685
686 sample_checkpoint = False
687
680 for step, batch in enumerate(train_dataloader): 688 for step, batch in enumerate(train_dataloader):
681 with accelerator.accumulate(unet): 689 with accelerator.accumulate(unet):
682 # Convert images to latent space 690 # Convert images to latent space
@@ -737,6 +745,9 @@ def main():
737 745
738 global_step += 1 746 global_step += 1
739 747
748 if global_step % args.sample_frequency == 0:
749 sample_checkpoint = True
750
740 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} 751 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
741 local_progress_bar.set_postfix(**logs) 752 local_progress_bar.set_postfix(**logs)
742 753
@@ -783,7 +794,11 @@ def main():
783 794
784 val_loss /= len(val_dataloader) 795 val_loss /= len(val_dataloader)
785 796
786 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) 797 accelerator.log({
798 "train/loss": train_loss,
799 "val/loss": val_loss,
800 "lr": lr_scheduler.get_last_lr()[0]
801 }, step=global_step)
787 802
788 local_progress_bar.clear() 803 local_progress_bar.clear()
789 global_progress_bar.clear() 804 global_progress_bar.clear()
@@ -792,7 +807,7 @@ def main():
792 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 807 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
793 min_val_loss = val_loss 808 min_val_loss = val_loss
794 809
795 if accelerator.is_main_process: 810 if sample_checkpoint and accelerator.is_main_process:
796 checkpointer.save_samples( 811 checkpointer.save_samples(
797 global_step, 812 global_step,
798 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 813 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
diff --git a/environment.yaml b/environment.yaml
index c9f498e..5ecc5a8 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -32,6 +32,6 @@ dependencies:
32 - test-tube>=0.7.5 32 - test-tube>=0.7.5
33 - torch-fidelity==0.3.0 33 - torch-fidelity==0.3.0
34 - torchmetrics==0.9.3 34 - torchmetrics==0.9.3
35 - transformers==4.22.1 35 - transformers==4.22.2
36 - triton==2.0.0.dev20220924 36 - triton==2.0.0.dev20220924
37 - xformers==0.0.13 37 - xformers==0.0.13
diff --git a/infer.py b/infer.py
index 6197aa3..a542534 100644
--- a/infer.py
+++ b/infer.py
@@ -5,12 +5,11 @@ import sys
5import shlex 5import shlex
6import cmd 6import cmd
7from pathlib import Path 7from pathlib import Path
8from torch import autocast
9import torch 8import torch
10import json 9import json
11from PIL import Image 10from PIL import Image
12from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler 11from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
13from transformers import CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor 12from transformers import CLIPTextModel, CLIPTokenizer
14from slugify import slugify 13from slugify import slugify
15from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 14from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
16from schedulers.scheduling_euler_a import EulerAScheduler 15from schedulers.scheduling_euler_a import EulerAScheduler
@@ -22,7 +21,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
22default_args = { 21default_args = {
23 "model": None, 22 "model": None,
24 "scheduler": "euler_a", 23 "scheduler": "euler_a",
25 "precision": "bf16", 24 "precision": "fp16",
26 "embeddings_dir": "embeddings", 25 "embeddings_dir": "embeddings",
27 "output_dir": "output/inference", 26 "output_dir": "output/inference",
28 "config": None, 27 "config": None,
@@ -260,7 +259,7 @@ def generate(output_dir, pipeline, args):
260 else: 259 else:
261 init_image = None 260 init_image = None
262 261
263 with autocast("cuda"): 262 with torch.autocast("cuda"), torch.inference_mode():
264 for i in range(args.batch_num): 263 for i in range(args.batch_num):
265 pipeline.set_progress_bar_config( 264 pipeline.set_progress_bar_config(
266 desc=f"Batch {i + 1} of {args.batch_num}", 265 desc=f"Batch {i + 1} of {args.batch_num}",
@@ -313,6 +312,9 @@ class CmdParse(cmd.Cmd):
313 args = run_parser(self.parser, default_cmds, elements) 312 args = run_parser(self.parser, default_cmds, elements)
314 except SystemExit: 313 except SystemExit:
315 self.parser.print_help() 314 self.parser.print_help()
315 except Exception as e:
316 print(e)
317 return
316 318
317 if len(args.prompt) == 0: 319 if len(args.prompt) == 0:
318 print('Try again with a prompt!') 320 print('Try again with a prompt!')
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a198cf6..bfecd1c 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -234,7 +234,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
234 latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) 234 latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
235 elif isinstance(latents, PIL.Image.Image): 235 elif isinstance(latents, PIL.Image.Image):
236 latents = preprocess(latents, width, height) 236 latents = preprocess(latents, width, height)
237 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist 237 latents = latents.to(device=self.device, dtype=latents_dtype)
238 latent_dist = self.vae.encode(latents).latent_dist
238 latents = latent_dist.sample(generator=generator) 239 latents = latent_dist.sample(generator=generator)
239 latents = 0.18215 * latents 240 latents = 0.18215 * latents
240 241
@@ -249,7 +250,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
249 timesteps = torch.tensor([timesteps] * batch_size, device=self.device) 250 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
250 251
251 # add noise to latents using the timesteps 252 # add noise to latents using the timesteps
252 noise = torch.randn(latents.shape, generator=generator, device=self.device) 253 noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
253 latents = self.scheduler.add_noise(latents, noise, timesteps) 254 latents = self.scheduler.add_noise(latents, noise, timesteps)
254 else: 255 else:
255 if latents.shape != latents_shape: 256 if latents.shape != latents_shape:
diff --git a/textual_inversion.py b/textual_inversion.py
index 4f2de9e..09871d4 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -4,6 +4,7 @@ import math
4import os 4import os
5import datetime 5import datetime
6import logging 6import logging
7import json
7from pathlib import Path 8from pathlib import Path
8 9
9import numpy as np 10import numpy as np
@@ -22,8 +23,6 @@ from tqdm.auto import tqdm
22from transformers import CLIPTextModel, CLIPTokenizer 23from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 24from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25import json
26import os
27 26
28from data.csv import CSVDataModule 27from data.csv import CSVDataModule
29 28
@@ -70,7 +69,7 @@ def parse_args():
70 parser.add_argument( 69 parser.add_argument(
71 "--num_class_images", 70 "--num_class_images",
72 type=int, 71 type=int,
73 default=4, 72 default=200,
74 help="How many class images to generate per training image." 73 help="How many class images to generate per training image."
75 ) 74 )
76 parser.add_argument( 75 parser.add_argument(
@@ -141,7 +140,7 @@ def parse_args():
141 parser.add_argument( 140 parser.add_argument(
142 "--lr_scheduler", 141 "--lr_scheduler",
143 type=str, 142 type=str,
144 default="constant", 143 default="linear",
145 help=( 144 help=(
146 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 145 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
147 ' "constant", "constant_with_warmup"]' 146 ' "constant", "constant_with_warmup"]'
@@ -402,20 +401,20 @@ class Checkpointer:
402 generator=generator, 401 generator=generator,
403 ) 402 )
404 403
405 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 404 with torch.inference_mode():
406 all_samples = [] 405 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
407 file_path = samples_path.joinpath(pool, f"step_{step}.png") 406 all_samples = []
408 file_path.parent.mkdir(parents=True, exist_ok=True) 407 file_path = samples_path.joinpath(pool, f"step_{step}.png")
408 file_path.parent.mkdir(parents=True, exist_ok=True)
409 409
410 data_enum = enumerate(data) 410 data_enum = enumerate(data)
411 411
412 for i in range(self.sample_batches): 412 for i in range(self.sample_batches):
413 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 413 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
414 prompt = [prompt.format(self.placeholder_token) 414 prompt = [prompt.format(self.placeholder_token)
415 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] 415 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
416 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] 416 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
417 417
418 with self.accelerator.autocast():
419 samples = pipeline( 418 samples = pipeline(
420 prompt=prompt, 419 prompt=prompt,
421 negative_prompt=nprompt, 420 negative_prompt=nprompt,
@@ -429,15 +428,15 @@ class Checkpointer:
429 output_type='pil' 428 output_type='pil'
430 )["sample"] 429 )["sample"]
431 430
432 all_samples += samples 431 all_samples += samples
433 432
434 del samples 433 del samples
435 434
436 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) 435 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
437 image_grid.save(file_path) 436 image_grid.save(file_path)
438 437
439 del all_samples 438 del all_samples
440 del image_grid 439 del image_grid
441 440
442 del unwrapped 441 del unwrapped
443 del scheduler 442 del scheduler
@@ -623,7 +622,7 @@ def main():
623 datamodule.setup() 622 datamodule.setup()
624 623
625 if args.num_class_images != 0: 624 if args.num_class_images != 0:
626 missing_data = [item for item in datamodule.data if not item[1].exists()] 625 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
627 626
628 if len(missing_data) != 0: 627 if len(missing_data) != 0:
629 batched_data = [missing_data[i:i+args.sample_batch_size] 628 batched_data = [missing_data[i:i+args.sample_batch_size]
@@ -643,20 +642,20 @@ def main():
643 pipeline.enable_attention_slicing() 642 pipeline.enable_attention_slicing()
644 pipeline.set_progress_bar_config(dynamic_ncols=True) 643 pipeline.set_progress_bar_config(dynamic_ncols=True)
645 644
646 for batch in batched_data: 645 with torch.inference_mode():
647 image_name = [p[1] for p in batch] 646 for batch in batched_data:
648 prompt = [p[2].format(args.initializer_token) for p in batch] 647 image_name = [p.class_image_path for p in batch]
649 nprompt = [p[3] for p in batch] 648 prompt = [p.prompt.format(args.initializer_token) for p in batch]
649 nprompt = [p.nprompt for p in batch]
650 650
651 with accelerator.autocast():
652 images = pipeline( 651 images = pipeline(
653 prompt=prompt, 652 prompt=prompt,
654 negative_prompt=nprompt, 653 negative_prompt=nprompt,
655 num_inference_steps=args.sample_steps 654 num_inference_steps=args.sample_steps
656 ).images 655 ).images
657 656
658 for i, image in enumerate(images): 657 for i, image in enumerate(images):
659 image.save(image_name[i]) 658 image.save(image_name[i])
660 659
661 del pipeline 660 del pipeline
662 661