author     Volpeon <git@volpeon.ink>    2022-10-03 21:28:52 +0200
committer  Volpeon <git@volpeon.ink>    2022-10-03 21:28:52 +0200
commit     46b6c09a18b41edff77c6881529b66733d788abe (patch)
tree       670e7cdda37ba7a010b570398a63dd38e357b6ce
parent     Small perf improvements (diff)
Dreambooth: Generate specialized class images from input prompts
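
Each row's `prompt` column in the training CSV is now treated as a template with a `{}` placeholder.
The same template is specialized twice via `str.format()`: with `--instance_identifier` to caption the
instance images, and with `--class_identifier` to caption the prior-preservation class images, which
are generated on demand into a `db_cls` directory next to the CSV. A minimal sketch of the templating
(the prompt text and identifier values below are illustrative assumptions, not part of this commit):

    # hypothetical CSV 'prompt' entry and identifier values, for illustration only
    prompt_template = "a photo of {} in the mountains"
    instance_prompt = prompt_template.format("<mytoken>")  # value of --instance_identifier
    class_prompt = prompt_template.format("person")        # value of --class_identifier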
-rw-r--r--  data/dreambooth/csv.py        | 112
-rw-r--r--  data/dreambooth/prompt.py     |   4
-rw-r--r--  data/textual_inversion/csv.py |   3
-rw-r--r--  dreambooth.py                 | 168
-rw-r--r--  textual_inversion.py          |   6
5 files changed, 129 insertions(+), 164 deletions(-)
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index c0b0067..4ebdc13 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -13,13 +13,11 @@ class CSVDataModule(pl.LightningDataModule):
                  batch_size,
                  data_file,
                  tokenizer,
-                 instance_prompt,
-                 class_data_root=None,
-                 class_prompt=None,
+                 instance_identifier,
+                 class_identifier=None,
                  size=512,
                  repeats=100,
                  interpolation="bicubic",
-                 identifier="*",
                  center_crop=False,
                  valid_set_size=None,
                  generator=None,
@@ -32,13 +30,14 @@ class CSVDataModule(pl.LightningDataModule):
             raise ValueError("data_file must be a file")
 
         self.data_root = self.data_file.parent
+        self.class_root = self.data_root.joinpath("db_cls")
+        self.class_root.mkdir(parents=True, exist_ok=True)
+
         self.tokenizer = tokenizer
-        self.instance_prompt = instance_prompt
-        self.class_data_root = class_data_root
-        self.class_prompt = class_prompt
+        self.instance_identifier = instance_identifier
+        self.class_identifier = class_identifier
         self.size = size
         self.repeats = repeats
-        self.identifier = identifier
         self.center_crop = center_crop
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -48,30 +47,36 @@ class CSVDataModule(pl.LightningDataModule):
 
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
+        instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values]
+        class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values]
         prompts = metadata['prompt'].values
-        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
-        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
-        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths)
+        self.data = [(i, c, p, n)
+                     for i, c, p, n, s
+                     in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips)
+                     if s != "x"]
 
     def setup(self, stage=None):
-        valid_set_size = int(len(self.data_full) * 0.2)
+        valid_set_size = int(len(self.data) * 0.2)
         if self.valid_set_size:
             valid_set_size = min(valid_set_size, self.valid_set_size)
-        train_set_size = len(self.data_full) - valid_set_size
-
-        self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
-
-        train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
-                                   class_data_root=self.class_data_root, class_prompt=self.class_prompt,
-                                   size=self.size, interpolation=self.interpolation, identifier=self.identifier,
-                                   center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size)
-        val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
-                                 size=self.size, interpolation=self.interpolation, identifier=self.identifier,
-                                 center_crop=self.center_crop, batch_size=self.batch_size)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
+        valid_set_size = max(valid_set_size, 1)
+        train_set_size = len(self.data) - valid_set_size
+
+        self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator)
+
+        train_dataset = CSVDataset(self.data_train, self.tokenizer,
+                                   instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
+                                   size=self.size, interpolation=self.interpolation,
+                                   center_crop=self.center_crop, repeats=self.repeats)
+        val_dataset = CSVDataset(self.data_val, self.tokenizer,
+                                 instance_identifier=self.instance_identifier,
+                                 size=self.size, interpolation=self.interpolation,
+                                 center_crop=self.center_crop, repeats=self.repeats)
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True,
                                             shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True,
                                           pin_memory=True, collate_fn=self.collate_fn)
 
     def train_dataloader(self):
@@ -85,39 +90,23 @@ class CSVDataset(Dataset):
     def __init__(self,
                  data,
                  tokenizer,
-                 instance_prompt,
-                 class_data_root=None,
-                 class_prompt=None,
+                 instance_identifier,
+                 class_identifier=None,
                  size=512,
                  repeats=1,
                  interpolation="bicubic",
-                 identifier="*",
                  center_crop=False,
-                 batch_size=1,
                  ):
 
         self.data = data
         self.tokenizer = tokenizer
-        self.instance_prompt = instance_prompt
-        self.identifier = identifier
-        self.batch_size = batch_size
+        self.instance_identifier = instance_identifier
+        self.class_identifier = class_identifier
         self.cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
 
-        if class_data_root is not None:
-            self.class_data_root = Path(class_data_root)
-            self.class_data_root.mkdir(parents=True, exist_ok=True)
-
-            self.class_images = list(self.class_data_root.iterdir())
-            self.num_class_images = len(self.class_images)
-            self._length = max(self.num_class_images, self.num_instance_images)
-
-            self.class_prompt = class_prompt
-        else:
-            self.class_data_root = None
-
         self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
                               "bilinear": transforms.InterpolationMode.BILINEAR,
                               "bicubic": transforms.InterpolationMode.BICUBIC,
@@ -134,46 +123,49 @@ class CSVDataset(Dataset):
         )
 
     def __len__(self):
-        return math.ceil(self._length / self.batch_size) * self.batch_size
+        return self._length
 
     def get_example(self, i):
-        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
+        instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
-        if image_path in self.cache:
-            return self.cache[image_path]
+        if instance_image_path in self.cache:
+            return self.cache[instance_image_path]
 
         example = {}
 
-        instance_image = Image.open(image_path)
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
+
+        instance_image = Image.open(instance_image_path)
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        prompt = prompt.format(self.identifier)
+        instance_prompt = prompt.format(self.instance_identifier)
 
-        example["prompts"] = prompt
-        example["nprompts"] = nprompt
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
-            self.instance_prompt,
+            instance_prompt,
             padding="do_not_pad",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
         ).input_ids
 
-        if self.class_data_root:
-            class_image = Image.open(self.class_images[i % self.num_class_images])
+        if self.class_identifier:
+            class_image = Image.open(class_image_path)
             if not class_image.mode == "RGB":
                 class_image = class_image.convert("RGB")
 
+            class_prompt = prompt.format(self.class_identifier)
+
             example["class_images"] = class_image
             example["class_prompt_ids"] = self.tokenizer(
-                self.class_prompt,
+                class_prompt,
                 padding="do_not_pad",
                 truncation=True,
                 max_length=self.tokenizer.model_max_length,
             ).input_ids
 
-        self.cache[image_path] = example
+        self.cache[instance_image_path] = example
         return example
 
     def __getitem__(self, i):
@@ -185,7 +177,7 @@ class CSVDataset(Dataset):
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
         example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
 
-        if self.class_data_root:
+        if self.class_identifier:
             example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
             example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
 
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
index 34f510d..b3a83ce 100644
--- a/data/dreambooth/prompt.py
+++ b/data/dreambooth/prompt.py
@@ -2,8 +2,9 @@ from torch.utils.data import Dataset
 
 
 class PromptDataset(Dataset):
-    def __init__(self, prompt, num_samples):
+    def __init__(self, prompt, nprompt, num_samples):
         self.prompt = prompt
+        self.nprompt = nprompt
         self.num_samples = num_samples
 
     def __len__(self):
@@ -12,5 +13,6 @@ class PromptDataset(Dataset):
     def __getitem__(self, index):
         example = {}
         example["prompt"] = self.prompt
+        example["nprompt"] = self.nprompt
         example["index"] = index
         return example
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 852b1cb..4c5e27e 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -52,13 +52,14 @@ class CSVDataModule(pl.LightningDataModule):
         valid_set_size = int(len(self.data_full) * 0.2)
         if self.valid_set_size:
             valid_set_size = min(valid_set_size, self.valid_set_size)
+        valid_set_size = max(valid_set_size, 1)
         train_set_size = len(self.data_full) - valid_set_size
 
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
                                    placeholder_token=self.placeholder_token, center_crop=self.center_crop)
-        val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
+        val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
                                  placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
         self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)
diff --git a/dreambooth.py b/dreambooth.py
index 9d6b8d6..2fe89ec 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,13 +13,12 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
-from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
@@ -56,7 +55,13 @@ def parse_args():
         help="A folder containing the training data."
     )
     parser.add_argument(
-        "--identifier",
+        "--instance_identifier",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
         type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
@@ -218,12 +223,6 @@ def parse_args():
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
-        "--instance_prompt",
-        type=str,
-        default=None,
-        help="The prompt with identifier specifing the instance",
-    )
-    parser.add_argument(
         "--class_data_dir",
         type=str,
         default=None,
@@ -231,12 +230,6 @@ def parse_args():
         help="A folder containing the training data of class images.",
     )
     parser.add_argument(
-        "--class_prompt",
-        type=str,
-        default=None,
-        help="The prompt to specify images in the same class as provided intance images.",
-    )
-    parser.add_argument(
         "--prior_loss_weight",
         type=float,
         default=1.0,
@@ -255,15 +248,6 @@ def parse_args():
         help="Max gradient norm."
     )
     parser.add_argument(
-        "--num_class_images",
-        type=int,
-        default=100,
-        help=(
-            "Minimal class images for prior perversation loss. If not have enough images, additional images will be"
-            " sampled with class_prompt."
-        ),
-    )
-    parser.add_argument(
         "--config",
         type=str,
         default=None,
@@ -286,21 +270,12 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.instance_prompt is None:
-        raise ValueError("You must specify --instance_prompt")
-
-    if args.identifier is None:
-        raise ValueError("You must specify --identifier")
+    if args.instance_identifier is None:
+        raise ValueError("You must specify --instance_identifier")
 
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
-    if args.with_prior_preservation:
-        if args.class_data_dir is None:
-            raise ValueError("You must specify --class_data_dir")
-        if args.class_prompt is None:
-            raise ValueError("You must specify --class_prompt")
-
     return args
 
 
@@ -443,7 +418,7 @@ def main():
     args = parse_args()
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -488,47 +463,6 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(text_encoder.parameters())
 
-    # Generate class images, if necessary
-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        class_images_dir.mkdir(parents=True, exist_ok=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            scheduler = EulerAScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            )
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=scheduler,
-            ).to(accelerator.device)
-            pipeline.enable_attention_slicing()
-            pipeline.set_progress_bar_config(disable=True)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-
-            for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process):
-                with accelerator.autocast():
-                    images = pipeline(example["prompt"]).images
-
-                for i, image in enumerate(images):
-                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -564,6 +498,7 @@ def main():
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
+        nprompts = [example["nprompts"] for example in examples]
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -579,6 +514,7 @@ def main():
 
         batch = {
             "prompts": prompts,
+            "nprompts": nprompts,
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
@@ -588,11 +524,9 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
-        instance_prompt=args.instance_prompt,
-        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
-        class_prompt=args.class_prompt,
+        instance_identifier=args.instance_identifier,
+        class_identifier=args.class_identifier,
         size=args.resolution,
-        identifier=args.identifier,
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
@@ -601,6 +535,46 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
+    if args.class_identifier:
+        missing_data = [item for item in datamodule.data if not item[1].exists()]
+
+        if len(missing_data) != 0:
+            batched_data = [missing_data[i:i+args.sample_batch_size]
+                            for i in range(0, len(missing_data), args.sample_batch_size)]
+
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+
+            for batch in batched_data:
+                image_name = [p[1] for p in batch]
+                prompt = [p[2] for p in batch]
+                nprompt = [p[3] for p in batch]
+
+                with accelerator.autocast():
+                    images = pipeline(
+                        prompt=prompt,
+                        negative_prompt=nprompt,
+                        num_inference_steps=args.sample_steps
+                    ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
@@ -718,23 +692,22 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                with accelerator.autocast():
-                    if args.with_prior_preservation:
-                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                if args.with_prior_preservation:
+                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
 
-                        # Compute instance loss
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    # Compute instance loss
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
-                        # Compute prior loss
-                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                                reduction="none").mean([1, 2, 3]).mean()
+                    # Compute prior loss
+                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
+                                            reduction="none").mean([1, 2, 3]).mean()
 
-                        # Add the prior loss to the instance loss.
-                        loss = loss + args.prior_loss_weight * prior_loss
-                    else:
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    # Add the prior loss to the instance loss.
+                    loss = loss + args.prior_loss_weight * prior_loss
+                else:
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
@@ -786,8 +759,7 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                with accelerator.autocast():
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 loss = loss.detach().item()
                 val_loss += loss
diff --git a/textual_inversion.py b/textual_inversion.py
index 5fc2338..4c4da29 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -694,8 +694,7 @@ def main():
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                with accelerator.autocast():
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
 
@@ -766,8 +765,7 @@ def main():
 
                 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                with accelerator.autocast():
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                 loss = loss.detach().item()
                 val_loss += loss