author    Volpeon <git@volpeon.ink>    2022-10-03 14:47:01 +0200
committer Volpeon <git@volpeon.ink>    2022-10-03 14:47:01 +0200
commit    c90099f06e0b461660b326fb6d86b69d86e78289
tree      df4ce274eed8f2a89bbd12f1a19c685ceac58ff2 /textual_inversion.py
parent    Fixed euler_a generator argument
Added negative prompt support for training scripts
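
In practice, each dataloader batch now carries an "nprompts" list alongside "prompts", and it is forwarded to the sampling pipeline as negative_prompt. A minimal sketch of the resulting call, assuming a diffusers StableDiffusionPipeline of the same era as this repo (the model ID and prompt strings are illustrative, not from this commit):

    import torch
    from diffusers import StableDiffusionPipeline

    pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

    result = pipeline(
        prompt=["a photo of a cat"],
        negative_prompt=["blurry, low quality"],  # what this commit threads through
        guidance_scale=7.5,
        num_inference_steps=50,
        output_type="pil",
    )
    # Diffusers releases of this era return a dict, indexed as result["sample"];
    # newer versions expose result.images instead.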
Diffstat (limited to 'textual_inversion.py')
 textual_inversion.py | 83 +++++++++++++++++++++++------------------------------------------------------------
 1 file changed, 23 insertions(+), 60 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 285aa0a..00d460f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -207,16 +207,10 @@ def parse_args():
         help="Size of sample images",
     )
     parser.add_argument(
-        "--stable_sample_batches",
+        "--sample_batches",
         type=int,
         default=1,
-        help="Number of fixed seed sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--random_sample_batches",
-        type=int,
-        default=1,
-        help="Number of random seed sample batches to generate per checkpoint",
+        help="Number of sample batches to generate per checkpoint",
     )
     parser.add_argument(
         "--sample_batch_size",
@@ -319,9 +313,8 @@ class Checkpointer:
         placeholder_token_id,
         output_dir,
         sample_image_size,
-        random_sample_batches,
+        sample_batches,
         sample_batch_size,
-        stable_sample_batches,
         seed
     ):
         self.datamodule = datamodule
@@ -334,9 +327,8 @@ class Checkpointer:
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed
-        self.random_sample_batches = random_sample_batches
+        self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
-        self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
     def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
@@ -385,63 +377,33 @@ class Checkpointer:
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
-        if self.stable_sample_batches > 0:
-            stable_latents = torch.randn(
-                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-                device=pipeline.device,
-                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
-            )
-
-            all_samples = []
-            file_path = samples_path.joinpath("stable", f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            data_enum = enumerate(val_data)
-
-            # Generate and save stable samples
-            for i in range(0, self.stable_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
-
-                with self.accelerator.autocast():
-                    samples = pipeline(
-                        prompt=prompt,
-                        height=self.sample_image_size,
-                        latents=stable_latents[:len(prompt)],
-                        width=self.sample_image_size,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    )["sample"]
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
-
-            del all_samples
-            del image_grid
-            del stable_latents
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
 
-        for data, pool in [(val_data, "val"), (train_data, "train")]:
+        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(data)
 
-            for i in range(0, self.random_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(self.sample_batches):
+                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
+                        negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
+                        latents=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
@@ -452,7 +414,7 @@ class Checkpointer:
 
                 del samples
 
-            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
             image_grid.save(file_path)
 
             del all_samples
@@ -461,6 +423,8 @@ class Checkpointer:
         del unwrapped
         del scheduler
         del pipeline
+        del generator
+        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -603,7 +567,7 @@ def main():
         placeholder_token=args.placeholder_token,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.stable_sample_batches
+        valid_set_size=args.sample_batch_size*args.sample_batches
     )
 
     datamodule.prepare_data()
@@ -623,8 +587,7 @@ def main():
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-        random_sample_batches=args.random_sample_batches,
-        stable_sample_batches=args.stable_sample_batches,
+        sample_batches=args.sample_batches,
         seed=args.seed or torch.random.seed()
     )
 
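
A note on the prompt-gathering comprehensions above: data_enum = enumerate(data) is created once per pool, outside the for i in range(self.sample_batches) loop, so every pass consumes the same iterator rather than rewinding to the first batch. A self-contained sketch of that behaviour with plain lists (fake_loader and the sizes are hypothetical stand-ins for the DataLoader):

    # Stand-in for a DataLoader yielding prompt batches of size 2 (hypothetical).
    fake_loader = [{"prompts": ["a", "b"]}, {"prompts": ["c", "d"]}, {"prompts": ["e", "f"]}]
    batch_size, sample_batch_size = 2, 2

    data_enum = enumerate(fake_loader)
    for i in range(2):  # sample_batches = 2
        # The filter scans all remaining items; it neither stops early nor rewinds.
        batches = [batch for j, batch in data_enum if j * batch_size < sample_batch_size]
        prompt = [p for batch in batches for p in batch["prompts"]][:sample_batch_size]
        print(i, prompt)
    # Prints: 0 ['a', 'b'] then 1 [] -- the first pass exhausted the iterator
    # while scanning past the batches that failed the index test.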