diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 81 |
1 file changed, 21 insertions(+), 60 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 75602dc..5fbf172 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -191,16 +191,10 @@ def parse_args(): | |||
191 | help="Size of sample images", | 191 | help="Size of sample images", |
192 | ) | 192 | ) |
193 | parser.add_argument( | 193 | parser.add_argument( |
194 | "--stable_sample_batches", | 194 | "--sample_batches", |
195 | type=int, | 195 | type=int, |
196 | default=1, | 196 | default=1, |
197 | help="Number of fixed seed sample batches to generate per checkpoint", | 197 | help="Number of sample batches to generate per checkpoint", |
198 | ) | ||
199 | parser.add_argument( | ||
200 | "--random_sample_batches", | ||
201 | type=int, | ||
202 | default=1, | ||
203 | help="Number of random seed sample batches to generate per checkpoint", | ||
204 | ) | 198 | ) |
205 | parser.add_argument( | 199 | parser.add_argument( |
206 | "--sample_batch_size", | 200 | "--sample_batch_size", |
@@ -331,9 +325,8 @@ class Checkpointer: | |||
331 | text_encoder, | 325 | text_encoder, |
332 | output_dir, | 326 | output_dir, |
333 | sample_image_size, | 327 | sample_image_size, |
334 | random_sample_batches, | 328 | sample_batches, |
335 | sample_batch_size, | 329 | sample_batch_size, |
336 | stable_sample_batches, | ||
337 | seed | 330 | seed |
338 | ): | 331 | ): |
339 | self.datamodule = datamodule | 332 | self.datamodule = datamodule |
@@ -345,9 +338,8 @@ class Checkpointer: | |||
345 | self.output_dir = output_dir | 338 | self.output_dir = output_dir |
346 | self.sample_image_size = sample_image_size | 339 | self.sample_image_size = sample_image_size |
347 | self.seed = seed | 340 | self.seed = seed |
348 | self.random_sample_batches = random_sample_batches | 341 | self.sample_batches = sample_batches |
349 | self.sample_batch_size = sample_batch_size | 342 | self.sample_batch_size = sample_batch_size |
350 | self.stable_sample_batches = stable_sample_batches | ||
351 | 343 | ||
352 | @torch.no_grad() | 344 | @torch.no_grad() |
353 | def checkpoint(self): | 345 | def checkpoint(self): |
@@ -396,63 +388,33 @@ class Checkpointer: | |||
396 | train_data = self.datamodule.train_dataloader() | 388 | train_data = self.datamodule.train_dataloader() |
397 | val_data = self.datamodule.val_dataloader() | 389 | val_data = self.datamodule.val_dataloader() |
398 | 390 | ||
399 | if self.stable_sample_batches > 0: | 391 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
400 | stable_latents = torch.randn( | 392 | stable_latents = torch.randn( |
401 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | 393 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), |
402 | device=pipeline.device, | 394 | device=pipeline.device, |
403 | generator=torch.Generator(device=pipeline.device).manual_seed(self.seed), | 395 | generator=generator, |
404 | ) | 396 | ) |
405 | |||
406 | all_samples = [] | ||
407 | file_path = samples_path.joinpath("stable", f"step_{step}.png") | ||
408 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
409 | |||
410 | data_enum = enumerate(val_data) | ||
411 | |||
412 | # Generate and save stable samples | ||
413 | for i in range(0, self.stable_sample_batches): | ||
414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | ||
415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | ||
416 | |||
417 | with self.accelerator.autocast(): | ||
418 | samples = pipeline( | ||
419 | prompt=prompt, | ||
420 | height=self.sample_image_size, | ||
421 | latents=stable_latents[:len(prompt)], | ||
422 | width=self.sample_image_size, | ||
423 | guidance_scale=guidance_scale, | ||
424 | eta=eta, | ||
425 | num_inference_steps=num_inference_steps, | ||
426 | output_type='pil' | ||
427 | )["sample"] | ||
428 | |||
429 | all_samples += samples | ||
430 | |||
431 | del samples | ||
432 | |||
433 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | ||
434 | image_grid.save(file_path) | ||
435 | |||
436 | del all_samples | ||
437 | del image_grid | ||
438 | del stable_latents | ||
439 | 397 | ||
440 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 398 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
441 | all_samples = [] | 399 | all_samples = [] |
442 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 400 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
443 | file_path.parent.mkdir(parents=True, exist_ok=True) | 401 | file_path.parent.mkdir(parents=True, exist_ok=True) |
444 | 402 | ||
445 | data_enum = enumerate(data) | 403 | data_enum = enumerate(data) |
446 | 404 | ||
447 | for i in range(0, self.random_sample_batches): | 405 | for i in range(self.sample_batches): |
448 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 406 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
449 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 407 | prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
408 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | ||
450 | 409 | ||
451 | with self.accelerator.autocast(): | 410 | with self.accelerator.autocast(): |
452 | samples = pipeline( | 411 | samples = pipeline( |
453 | prompt=prompt, | 412 | prompt=prompt, |
413 | negative_prompt=nprompt, | ||
454 | height=self.sample_image_size, | 414 | height=self.sample_image_size, |
455 | width=self.sample_image_size, | 415 | width=self.sample_image_size, |
416 | latents=latents[:len(prompt)] if latents is not None else None, | ||
417 | generator=generator if latents is not None else None, | ||
456 | guidance_scale=guidance_scale, | 418 | guidance_scale=guidance_scale, |
457 | eta=eta, | 419 | eta=eta, |
458 | num_inference_steps=num_inference_steps, | 420 | num_inference_steps=num_inference_steps, |
@@ -463,7 +425,7 @@ class Checkpointer: | |||
463 | 425 | ||
464 | del samples | 426 | del samples |
465 | 427 | ||
466 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 428 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
467 | image_grid.save(file_path) | 429 | image_grid.save(file_path) |
468 | 430 | ||
469 | del all_samples | 431 | del all_samples |
@@ -630,7 +592,7 @@ def main(): | |||
630 | identifier=args.identifier, | 592 | identifier=args.identifier, |
631 | repeats=args.repeats, | 593 | repeats=args.repeats, |
632 | center_crop=args.center_crop, | 594 | center_crop=args.center_crop, |
633 | valid_set_size=args.sample_batch_size*args.stable_sample_batches, | 595 | valid_set_size=args.sample_batch_size*args.sample_batches, |
634 | collate_fn=collate_fn) | 596 | collate_fn=collate_fn) |
635 | 597 | ||
636 | datamodule.prepare_data() | 598 | datamodule.prepare_data() |
@@ -649,8 +611,7 @@ def main(): | |||
649 | output_dir=basepath, | 611 | output_dir=basepath, |
650 | sample_image_size=args.sample_image_size, | 612 | sample_image_size=args.sample_image_size, |
651 | sample_batch_size=args.sample_batch_size, | 613 | sample_batch_size=args.sample_batch_size, |
652 | random_sample_batches=args.random_sample_batches, | 614 | sample_batches=args.sample_batches, |
653 | stable_sample_batches=args.stable_sample_batches, | ||
654 | seed=args.seed or torch.random.seed() | 615 | seed=args.seed or torch.random.seed() |
655 | ) | 616 | ) |
656 | 617 | ||