diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 75 |
1 files changed, 45 insertions, 30 deletions
diff --git a/dreambooth.py b/dreambooth.py index a26bea7..7b61c45 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -3,6 +3,7 @@ import math | |||
3 | import os | 3 | import os |
4 | import datetime | 4 | import datetime |
5 | import logging | 5 | import logging |
6 | import json | ||
6 | from pathlib import Path | 7 | from pathlib import Path |
7 | 8 | ||
8 | import numpy as np | 9 | import numpy as np |
@@ -21,7 +22,6 @@ from tqdm.auto import tqdm | |||
21 | from transformers import CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPTextModel, CLIPTokenizer |
22 | from slugify import slugify | 23 | from slugify import slugify |
23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
24 | import json | ||
25 | 25 | ||
26 | from data.csv import CSVDataModule | 26 | from data.csv import CSVDataModule |
27 | 27 | ||
@@ -68,7 +68,7 @@ def parse_args(): | |||
68 | parser.add_argument( | 68 | parser.add_argument( |
69 | "--num_class_images", | 69 | "--num_class_images", |
70 | type=int, | 70 | type=int, |
71 | default=4, | 71 | default=200, |
72 | help="How many class images to generate per training image." | 72 | help="How many class images to generate per training image." |
73 | ) | 73 | ) |
74 | parser.add_argument( | 74 | parser.add_argument( |
@@ -140,7 +140,7 @@ def parse_args(): | |||
140 | parser.add_argument( | 140 | parser.add_argument( |
141 | "--lr_scheduler", | 141 | "--lr_scheduler", |
142 | type=str, | 142 | type=str, |
143 | default="constant", | 143 | default="linear", |
144 | help=( | 144 | help=( |
145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 145 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
146 | ' "constant", "constant_with_warmup"]' | 146 | ' "constant", "constant_with_warmup"]' |
@@ -199,6 +199,12 @@ def parse_args(): | |||
199 | help="For distributed training: local_rank" | 199 | help="For distributed training: local_rank" |
200 | ) | 200 | ) |
201 | parser.add_argument( | 201 | parser.add_argument( |
202 | "--sample_frequency", | ||
203 | type=int, | ||
204 | default=100, | ||
205 | help="How often to save a checkpoint and sample image", | ||
206 | ) | ||
207 | parser.add_argument( | ||
202 | "--sample_image_size", | 208 | "--sample_image_size", |
203 | type=int, | 209 | type=int, |
204 | default=512, | 210 | default=512, |
@@ -366,20 +372,20 @@ class Checkpointer: | |||
366 | generator=generator, | 372 | generator=generator, |
367 | ) | 373 | ) |
368 | 374 | ||
369 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 375 | with torch.inference_mode(): |
370 | all_samples = [] | 376 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
371 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 377 | all_samples = [] |
372 | file_path.parent.mkdir(parents=True, exist_ok=True) | 378 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
379 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
373 | 380 | ||
374 | data_enum = enumerate(data) | 381 | data_enum = enumerate(data) |
375 | 382 | ||
376 | for i in range(self.sample_batches): | 383 | for i in range(self.sample_batches): |
377 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 384 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
378 | prompt = [prompt.format(self.instance_identifier) | 385 | prompt = [prompt.format(self.instance_identifier) |
379 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] | 386 | for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
380 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | 387 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] |
381 | 388 | ||
382 | with self.accelerator.autocast(): | ||
383 | samples = pipeline( | 389 | samples = pipeline( |
384 | prompt=prompt, | 390 | prompt=prompt, |
385 | negative_prompt=nprompt, | 391 | negative_prompt=nprompt, |
@@ -393,15 +399,15 @@ class Checkpointer: | |||
393 | output_type='pil' | 399 | output_type='pil' |
394 | )["sample"] | 400 | )["sample"] |
395 | 401 | ||
396 | all_samples += samples | 402 | all_samples += samples |
397 | 403 | ||
398 | del samples | 404 | del samples |
399 | 405 | ||
400 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | 406 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
401 | image_grid.save(file_path) | 407 | image_grid.save(file_path) |
402 | 408 | ||
403 | del all_samples | 409 | del all_samples |
404 | del image_grid | 410 | del image_grid |
405 | 411 | ||
406 | del unwrapped | 412 | del unwrapped |
407 | del scheduler | 413 | del scheduler |
@@ -538,7 +544,7 @@ def main(): | |||
538 | datamodule.setup() | 544 | datamodule.setup() |
539 | 545 | ||
540 | if args.num_class_images != 0: | 546 | if args.num_class_images != 0: |
541 | missing_data = [item for item in datamodule.data if not item[1].exists()] | 547 | missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] |
542 | 548 | ||
543 | if len(missing_data) != 0: | 549 | if len(missing_data) != 0: |
544 | batched_data = [missing_data[i:i+args.sample_batch_size] | 550 | batched_data = [missing_data[i:i+args.sample_batch_size] |
@@ -558,20 +564,20 @@ def main(): | |||
558 | pipeline.enable_attention_slicing() | 564 | pipeline.enable_attention_slicing() |
559 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 565 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
560 | 566 | ||
561 | for batch in batched_data: | 567 | with torch.inference_mode(): |
562 | image_name = [p[1] for p in batch] | 568 | for batch in batched_data: |
563 | prompt = [p[2].format(args.class_identifier) for p in batch] | 569 | image_name = [p.class_image_path for p in batch] |
564 | nprompt = [p[3] for p in batch] | 570 | prompt = [p.prompt.format(args.class_identifier) for p in batch] |
571 | nprompt = [p.nprompt for p in batch] | ||
565 | 572 | ||
566 | with accelerator.autocast(): | ||
567 | images = pipeline( | 573 | images = pipeline( |
568 | prompt=prompt, | 574 | prompt=prompt, |
569 | negative_prompt=nprompt, | 575 | negative_prompt=nprompt, |
570 | num_inference_steps=args.sample_steps | 576 | num_inference_steps=args.sample_steps |
571 | ).images | 577 | ).images |
572 | 578 | ||
573 | for i, image in enumerate(images): | 579 | for i, image in enumerate(images): |
574 | image.save(image_name[i]) | 580 | image.save(image_name[i]) |
575 | 581 | ||
576 | del pipeline | 582 | del pipeline |
577 | 583 | ||
@@ -677,6 +683,8 @@ def main(): | |||
677 | unet.train() | 683 | unet.train() |
678 | train_loss = 0.0 | 684 | train_loss = 0.0 |
679 | 685 | ||
686 | sample_checkpoint = False | ||
687 | |||
680 | for step, batch in enumerate(train_dataloader): | 688 | for step, batch in enumerate(train_dataloader): |
681 | with accelerator.accumulate(unet): | 689 | with accelerator.accumulate(unet): |
682 | # Convert images to latent space | 690 | # Convert images to latent space |
@@ -737,6 +745,9 @@ def main(): | |||
737 | 745 | ||
738 | global_step += 1 | 746 | global_step += 1 |
739 | 747 | ||
748 | if global_step % args.sample_frequency == 0: | ||
749 | sample_checkpoint = True | ||
750 | |||
740 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 751 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
741 | local_progress_bar.set_postfix(**logs) | 752 | local_progress_bar.set_postfix(**logs) |
742 | 753 | ||
@@ -783,7 +794,11 @@ def main(): | |||
783 | 794 | ||
784 | val_loss /= len(val_dataloader) | 795 | val_loss /= len(val_dataloader) |
785 | 796 | ||
786 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | 797 | accelerator.log({ |
798 | "train/loss": train_loss, | ||
799 | "val/loss": val_loss, | ||
800 | "lr": lr_scheduler.get_last_lr()[0] | ||
801 | }, step=global_step) | ||
787 | 802 | ||
788 | local_progress_bar.clear() | 803 | local_progress_bar.clear() |
789 | global_progress_bar.clear() | 804 | global_progress_bar.clear() |
@@ -792,7 +807,7 @@ def main(): | |||
792 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 807 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
793 | min_val_loss = val_loss | 808 | min_val_loss = val_loss |
794 | 809 | ||
795 | if accelerator.is_main_process: | 810 | if sample_checkpoint and accelerator.is_main_process: |
796 | checkpointer.save_samples( | 811 | checkpointer.save_samples( |
797 | global_step, | 812 | global_step, |
798 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 813 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |