summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-03 14:47:01 +0200
committerVolpeon <git@volpeon.ink>2022-10-03 14:47:01 +0200
commitc90099f06e0b461660b326fb6d86b69d86e78289 (patch)
treedf4ce274eed8f2a89bbd12f1a19c685ceac58ff2 /dreambooth.py
parentFixed euler_a generator argument (diff)
downloadtextual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.gz
textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.bz2
textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.zip
Added negative prompt support for training scripts
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py81
1 files changed, 21 insertions, 60 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 75602dc..5fbf172 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -191,16 +191,10 @@ def parse_args():
191 help="Size of sample images", 191 help="Size of sample images",
192 ) 192 )
193 parser.add_argument( 193 parser.add_argument(
194 "--stable_sample_batches", 194 "--sample_batches",
195 type=int, 195 type=int,
196 default=1, 196 default=1,
197 help="Number of fixed seed sample batches to generate per checkpoint", 197 help="Number of sample batches to generate per checkpoint",
198 )
199 parser.add_argument(
200 "--random_sample_batches",
201 type=int,
202 default=1,
203 help="Number of random seed sample batches to generate per checkpoint",
204 ) 198 )
205 parser.add_argument( 199 parser.add_argument(
206 "--sample_batch_size", 200 "--sample_batch_size",
@@ -331,9 +325,8 @@ class Checkpointer:
331 text_encoder, 325 text_encoder,
332 output_dir, 326 output_dir,
333 sample_image_size, 327 sample_image_size,
334 random_sample_batches, 328 sample_batches,
335 sample_batch_size, 329 sample_batch_size,
336 stable_sample_batches,
337 seed 330 seed
338 ): 331 ):
339 self.datamodule = datamodule 332 self.datamodule = datamodule
@@ -345,9 +338,8 @@ class Checkpointer:
345 self.output_dir = output_dir 338 self.output_dir = output_dir
346 self.sample_image_size = sample_image_size 339 self.sample_image_size = sample_image_size
347 self.seed = seed 340 self.seed = seed
348 self.random_sample_batches = random_sample_batches 341 self.sample_batches = sample_batches
349 self.sample_batch_size = sample_batch_size 342 self.sample_batch_size = sample_batch_size
350 self.stable_sample_batches = stable_sample_batches
351 343
352 @torch.no_grad() 344 @torch.no_grad()
353 def checkpoint(self): 345 def checkpoint(self):
@@ -396,63 +388,33 @@ class Checkpointer:
396 train_data = self.datamodule.train_dataloader() 388 train_data = self.datamodule.train_dataloader()
397 val_data = self.datamodule.val_dataloader() 389 val_data = self.datamodule.val_dataloader()
398 390
399 if self.stable_sample_batches > 0: 391 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
400 stable_latents = torch.randn( 392 stable_latents = torch.randn(
401 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), 393 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
402 device=pipeline.device, 394 device=pipeline.device,
403 generator=torch.Generator(device=pipeline.device).manual_seed(self.seed), 395 generator=generator,
404 ) 396 )
405
406 all_samples = []
407 file_path = samples_path.joinpath("stable", f"step_{step}.png")
408 file_path.parent.mkdir(parents=True, exist_ok=True)
409
410 data_enum = enumerate(val_data)
411
412 # Generate and save stable samples
413 for i in range(0, self.stable_sample_batches):
414 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
415 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
416
417 with self.accelerator.autocast():
418 samples = pipeline(
419 prompt=prompt,
420 height=self.sample_image_size,
421 latents=stable_latents[:len(prompt)],
422 width=self.sample_image_size,
423 guidance_scale=guidance_scale,
424 eta=eta,
425 num_inference_steps=num_inference_steps,
426 output_type='pil'
427 )["sample"]
428
429 all_samples += samples
430
431 del samples
432
433 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
434 image_grid.save(file_path)
435
436 del all_samples
437 del image_grid
438 del stable_latents
439 397
440 for data, pool in [(val_data, "val"), (train_data, "train")]: 398 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
441 all_samples = [] 399 all_samples = []
442 file_path = samples_path.joinpath(pool, f"step_{step}.png") 400 file_path = samples_path.joinpath(pool, f"step_{step}.png")
443 file_path.parent.mkdir(parents=True, exist_ok=True) 401 file_path.parent.mkdir(parents=True, exist_ok=True)
444 402
445 data_enum = enumerate(data) 403 data_enum = enumerate(data)
446 404
447 for i in range(0, self.random_sample_batches): 405 for i in range(self.sample_batches):
448 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 406 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
449 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] 407 prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
408 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
450 409
451 with self.accelerator.autocast(): 410 with self.accelerator.autocast():
452 samples = pipeline( 411 samples = pipeline(
453 prompt=prompt, 412 prompt=prompt,
413 negative_prompt=nprompt,
454 height=self.sample_image_size, 414 height=self.sample_image_size,
455 width=self.sample_image_size, 415 width=self.sample_image_size,
416 latents=latents[:len(prompt)] if latents is not None else None,
417 generator=generator if latents is not None else None,
456 guidance_scale=guidance_scale, 418 guidance_scale=guidance_scale,
457 eta=eta, 419 eta=eta,
458 num_inference_steps=num_inference_steps, 420 num_inference_steps=num_inference_steps,
@@ -463,7 +425,7 @@ class Checkpointer:
463 425
464 del samples 426 del samples
465 427
466 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 428 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
467 image_grid.save(file_path) 429 image_grid.save(file_path)
468 430
469 del all_samples 431 del all_samples
@@ -630,7 +592,7 @@ def main():
630 identifier=args.identifier, 592 identifier=args.identifier,
631 repeats=args.repeats, 593 repeats=args.repeats,
632 center_crop=args.center_crop, 594 center_crop=args.center_crop,
633 valid_set_size=args.sample_batch_size*args.stable_sample_batches, 595 valid_set_size=args.sample_batch_size*args.sample_batches,
634 collate_fn=collate_fn) 596 collate_fn=collate_fn)
635 597
636 datamodule.prepare_data() 598 datamodule.prepare_data()
@@ -649,8 +611,7 @@ def main():
649 output_dir=basepath, 611 output_dir=basepath,
650 sample_image_size=args.sample_image_size, 612 sample_image_size=args.sample_image_size,
651 sample_batch_size=args.sample_batch_size, 613 sample_batch_size=args.sample_batch_size,
652 random_sample_batches=args.random_sample_batches, 614 sample_batches=args.sample_batches,
653 stable_sample_batches=args.stable_sample_batches,
654 seed=args.seed or torch.random.seed() 615 seed=args.seed or torch.random.seed()
655 ) 616 )
656 617