author    Volpeon <git@volpeon.ink>  2022-10-03 14:47:01 +0200
committer Volpeon <git@volpeon.ink>  2022-10-03 14:47:01 +0200
commit    c90099f06e0b461660b326fb6d86b69d86e78289
tree      df4ce274eed8f2a89bbd12f1a19c685ceac58ff2
parent    Fixed euler_a generator argument
Added negative prompt support for training scripts
 data/dreambooth/csv.py        | 15
 data/textual_inversion/csv.py | 17
 dreambooth.py                 | 81
 textual_inversion.py          | 83
 4 files changed, 63 insertions(+), 133 deletions(-)
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 08ed49c..71aa1eb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule):
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+        prompts = metadata['prompt'].values
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
+        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data_full) * 0.2)
@@ -135,7 +136,7 @@ class CSVDataset(Dataset):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        image_path, text = self.data[i % self.num_instance_images]
+        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
@@ -146,9 +147,10 @@ class CSVDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        text = text.format(self.identifier)
+        prompt = prompt.format(self.identifier)
 
-        example["prompts"] = text
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
@@ -178,6 +180,7 @@ class CSVDataset(Dataset):
         unprocessed_example = self.get_example(i)
 
         example["prompts"] = unprocessed_example["prompts"]
+        example["nprompts"] = unprocessed_example["nprompts"]
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
         example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
 
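
Note on the data format: the metadata CSV is now expected to carry a prompt column (formerly caption) and may optionally carry nprompt and skip columns; missing optional columns fall back to empty strings, so older files keep working once caption is renamed. A minimal sketch of the expected file and the fallback logic, mirroring prepare_data() above (the file name and row values are hypothetical):

    import pandas as pd

    # Hypothetical metadata.csv -- only 'image' and 'prompt' are required:
    #
    #   image,prompt,nprompt,skip
    #   img/001.jpg,a photo of {},"blurry, low quality",
    #   img/002.jpg,a painting of {},,x
    metadata = pd.read_csv("metadata.csv")

    image_paths = metadata['image'].values
    prompts = metadata['prompt'].values
    # Missing columns default to empty strings, as in prepare_data() above.
    nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
    skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
    # Rows whose skip cell is "x" are excluded from training.
    data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]

One caveat worth knowing: if the column exists but a cell is left empty, pandas reads that cell as NaN rather than "", so omitting the column entirely is the clean way to opt out of negative prompts.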
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 3ac57df..64f0c28 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule):
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+        prompts = metadata['prompt'].values
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
+        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data_full) * 0.2)
@@ -109,7 +110,7 @@ class CSVDataset(Dataset):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        image_path, text = self.data[i % self.num_instance_images]
+        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
@@ -120,12 +121,13 @@ class CSVDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        text = text.format(self.placeholder_token)
+        prompt = prompt.format(self.placeholder_token)
 
-        example["prompts"] = text
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
         example["pixel_values"] = instance_image
         example["input_ids"] = self.tokenizer(
-            text,
+            prompt,
             padding="max_length",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
@@ -140,6 +142,7 @@ class CSVDataset(Dataset):
         unprocessed_example = self.get_example(i)
 
         example["prompts"] = unprocessed_example["prompts"]
+        example["nprompts"] = unprocessed_example["nprompts"]
         example["input_ids"] = unprocessed_example["input_ids"]
         example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
 
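
The textual-inversion data module gets the identical treatment; the only differences are that prompts are formatted with the placeholder token instead of the DreamBooth identifier, and that only the positive prompt is tokenized -- the negative prompt rides along as a plain string for the sampling pipeline to consume. Roughly, each example now looks like this (a sketch with dummy tensor values; the keys follow the code above):

    import torch

    example = {
        "prompts": "a painting of <token>",  # placeholder substituted via .format()
        "nprompts": "blurry, low quality",   # untokenized; only used when sampling
        "pixel_values": torch.zeros(3, 512, 512),        # transformed instance image
        "input_ids": torch.zeros(77, dtype=torch.long),  # tokenized positive prompt
    }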
diff --git a/dreambooth.py b/dreambooth.py
index 75602dc..5fbf172 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -191,16 +191,10 @@ def parse_args():
         help="Size of sample images",
     )
     parser.add_argument(
-        "--stable_sample_batches",
+        "--sample_batches",
         type=int,
         default=1,
-        help="Number of fixed seed sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--random_sample_batches",
-        type=int,
-        default=1,
-        help="Number of random seed sample batches to generate per checkpoint",
+        help="Number of sample batches to generate per checkpoint",
     )
     parser.add_argument(
         "--sample_batch_size",
@@ -331,9 +325,8 @@ class Checkpointer:
         text_encoder,
         output_dir,
         sample_image_size,
-        random_sample_batches,
+        sample_batches,
         sample_batch_size,
-        stable_sample_batches,
         seed
     ):
         self.datamodule = datamodule
@@ -345,9 +338,8 @@ class Checkpointer:
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed
-        self.random_sample_batches = random_sample_batches
+        self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
-        self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
     def checkpoint(self):
@@ -396,63 +388,33 @@ class Checkpointer:
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
-        if self.stable_sample_batches > 0:
-            stable_latents = torch.randn(
-                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-                device=pipeline.device,
-                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
-            )
-
-            all_samples = []
-            file_path = samples_path.joinpath("stable", f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            data_enum = enumerate(val_data)
-
-            # Generate and save stable samples
-            for i in range(0, self.stable_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
-
-                with self.accelerator.autocast():
-                    samples = pipeline(
-                        prompt=prompt,
-                        height=self.sample_image_size,
-                        latents=stable_latents[:len(prompt)],
-                        width=self.sample_image_size,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    )["sample"]
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
-
-            del all_samples
-            del image_grid
-            del stable_latents
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
 
-        for data, pool in [(val_data, "val"), (train_data, "train")]:
+        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(data)
 
-            for i in range(0, self.random_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(self.sample_batches):
+                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
+                        negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
+                        latents=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
@@ -463,7 +425,7 @@ class Checkpointer:
 
                 del samples
 
-            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
             image_grid.save(file_path)
 
             del all_samples
@@ -630,7 +592,7 @@ def main():
         identifier=args.identifier,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.stable_sample_batches,
+        valid_set_size=args.sample_batch_size*args.sample_batches,
         collate_fn=collate_fn)
 
     datamodule.prepare_data()
@@ -649,8 +611,7 @@ def main():
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-        random_sample_batches=args.random_sample_batches,
-        stable_sample_batches=args.stable_sample_batches,
+        sample_batches=args.sample_batches,
         seed=args.seed or torch.random.seed()
     )
 
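
The Checkpointer rework above is easier to read condensed: the former fixed-seed "stable" pass and the random val/train passes collapse into one loop over three pools, where only the stable pool pins its latents (and reuses the seeded generator), and every pool now passes the batch's negative prompts through. A sketch of the resulting control flow, assuming a diffusers-style pipeline whose __call__ accepts prompt, negative_prompt, latents, and generator, as in the patched call:

    import torch

    def sample_pools(pipeline, train_data, val_data, sample_batches, sample_batch_size, seed, size=512):
        # One seeded generator; fixed latents make the "stable" pool reproducible
        # across checkpoints, while "val" and "train" draw fresh noise every time.
        generator = torch.Generator(device=pipeline.device).manual_seed(seed)
        stable_latents = torch.randn(
            (sample_batch_size, pipeline.unet.in_channels, size // 8, size // 8),
            device=pipeline.device,
            generator=generator,
        )

        for pool, data, latents in [("stable", val_data, stable_latents),
                                    ("val", val_data, None),
                                    ("train", train_data, None)]:
            data_enum = enumerate(data)
            for _ in range(sample_batches):
                # Take just enough dataloader batches to fill one sample batch, then truncate.
                batches = [b for j, b in data_enum if j * data.batch_size < sample_batch_size]
                prompt = [p for b in batches for p in b["prompts"]][:sample_batch_size]
                nprompt = [p for b in batches for p in b["nprompts"]][:sample_batch_size]
                images = pipeline(
                    prompt=prompt,
                    negative_prompt=nprompt,  # new in this commit
                    height=size, width=size,
                    latents=latents[:len(prompt)] if latents is not None else None,
                    generator=generator if latents is not None else None,
                )["sample"]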
diff --git a/textual_inversion.py b/textual_inversion.py
index 285aa0a..00d460f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -207,16 +207,10 @@ def parse_args():
         help="Size of sample images",
     )
     parser.add_argument(
-        "--stable_sample_batches",
+        "--sample_batches",
         type=int,
         default=1,
-        help="Number of fixed seed sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--random_sample_batches",
-        type=int,
-        default=1,
-        help="Number of random seed sample batches to generate per checkpoint",
+        help="Number of sample batches to generate per checkpoint",
     )
     parser.add_argument(
         "--sample_batch_size",
@@ -319,9 +313,8 @@ class Checkpointer:
         placeholder_token_id,
         output_dir,
         sample_image_size,
-        random_sample_batches,
+        sample_batches,
         sample_batch_size,
-        stable_sample_batches,
         seed
     ):
         self.datamodule = datamodule
@@ -334,9 +327,8 @@ class Checkpointer:
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed
-        self.random_sample_batches = random_sample_batches
+        self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
-        self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
     def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
@@ -385,63 +377,33 @@ class Checkpointer:
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
-        if self.stable_sample_batches > 0:
-            stable_latents = torch.randn(
-                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-                device=pipeline.device,
-                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
-            )
-
-            all_samples = []
-            file_path = samples_path.joinpath("stable", f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            data_enum = enumerate(val_data)
-
-            # Generate and save stable samples
-            for i in range(0, self.stable_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
-
-                with self.accelerator.autocast():
-                    samples = pipeline(
-                        prompt=prompt,
-                        height=self.sample_image_size,
-                        latents=stable_latents[:len(prompt)],
-                        width=self.sample_image_size,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    )["sample"]
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
-
-            del all_samples
-            del image_grid
-            del stable_latents
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
 
-        for data, pool in [(val_data, "val"), (train_data, "train")]:
+        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(data)
 
-            for i in range(0, self.random_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(self.sample_batches):
+                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
+                        negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
+                        latents=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
@@ -452,7 +414,7 @@ class Checkpointer:
 
                 del samples
 
-            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
             image_grid.save(file_path)
 
             del all_samples
@@ -461,6 +423,8 @@ class Checkpointer:
         del unwrapped
         del scheduler
         del pipeline
+        del generator
+        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -603,7 +567,7 @@ def main():
         placeholder_token=args.placeholder_token,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.stable_sample_batches
+        valid_set_size=args.sample_batch_size*args.sample_batches
     )
 
     datamodule.prepare_data()
@@ -623,8 +587,7 @@ def main():
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-        random_sample_batches=args.random_sample_batches,
-        stable_sample_batches=args.stable_sample_batches,
+        sample_batches=args.sample_batches,
         seed=args.seed or torch.random.seed()
     )
 
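
For context on why passing negative_prompt through is all the training scripts need to do: under classifier-free guidance the negative prompt's embedding simply replaces the empty-string unconditional embedding, so each denoising step steers away from it. A generic sketch of that step (standard Stable Diffusion CFG, not code from this repository):

    import torch

    def guided_noise(unet, latents, t, cond_emb, uncond_emb, guidance_scale):
        # With a negative prompt, uncond_emb is the embedding of that prompt
        # instead of the embedding of "".
        latent_in = torch.cat([latents, latents])
        emb = torch.cat([uncond_emb, cond_emb])
        noise_uncond, noise_cond = unet(latent_in, t, encoder_hidden_states=emb).sample.chunk(2)
        # Push away from the (negative) unconditional prediction, toward the conditional one.
        return noise_uncond + guidance_scale * (noise_cond - noise_uncond)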