| field | value | date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-10-08 21:56:54 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2022-10-08 21:56:54 +0200 |
| commit | 6aadb34af4fe5ca2dfc92fae8eee87610a5848ad (patch) | |
| tree | f490b4794366e78f7b079eb04de1c7c00e17d34a /dreambooth.py | |
| parent | Fix small details (diff) | |
Update
Diffstat (limited to 'dreambooth.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | dreambooth.py | 75 |

1 file changed, 45 insertions, 30 deletions
```diff
diff --git a/dreambooth.py b/dreambooth.py
index a26bea7..7b61c45 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -3,6 +3,7 @@ import math
 import os
 import datetime
 import logging
+import json
 from pathlib import Path
 
 import numpy as np
@@ -21,7 +22,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-import json
 
 from data.csv import CSVDataModule
 
@@ -68,7 +68,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=4,
+        default=200,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -140,7 +140,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="constant",
+        default="linear",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -199,6 +199,12 @@ def parse_args():
         help="For distributed training: local_rank"
     )
     parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=100,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
         "--sample_image_size",
         type=int,
         default=512,
@@ -366,20 +372,20 @@ class Checkpointer:
             generator=generator,
         )
 
-        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
+        with torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
+                data_enum = enumerate(data)
 
-            for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                for i in range(self.sample_batches):
+                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                    prompt = [prompt.format(self.instance_identifier)
+                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
-                with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
@@ -393,15 +399,15 @@ class Checkpointer:
                         output_type='pil'
                     )["sample"]
 
-                all_samples += samples
+                    all_samples += samples
 
-                del samples
+                    del samples
 
-            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path)
 
-            del all_samples
-            del image_grid
+                del all_samples
+                del image_grid
 
         del unwrapped
         del scheduler
@@ -538,7 +544,7 @@ def main():
     datamodule.setup()
 
     if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data if not item[1].exists()]
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
             batched_data = [missing_data[i:i+args.sample_batch_size]
@@ -558,20 +564,20 @@ def main():
             pipeline.enable_attention_slicing()
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-            for batch in batched_data:
-                image_name = [p[1] for p in batch]
-                prompt = [p[2].format(args.class_identifier) for p in batch]
-                nprompt = [p[3] for p in batch]
+            with torch.inference_mode():
+                for batch in batched_data:
+                    image_name = [p.class_image_path for p in batch]
+                    prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                    nprompt = [p.nprompt for p in batch]
 
-                with accelerator.autocast():
                     images = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
                         num_inference_steps=args.sample_steps
                     ).images
 
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
+                    for i, image in enumerate(images):
+                        image.save(image_name[i])
 
             del pipeline
 
@@ -677,6 +683,8 @@ def main():
         unet.train()
         train_loss = 0.0
 
+        sample_checkpoint = False
+
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
@@ -737,6 +745,9 @@ def main():
 
                 global_step += 1
 
+                if global_step % args.sample_frequency == 0:
+                    sample_checkpoint = True
+
             logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
             local_progress_bar.set_postfix(**logs)
 
@@ -783,7 +794,11 @@ def main():
 
         val_loss /= len(val_dataloader)
 
-        accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
+        accelerator.log({
+            "train/loss": train_loss,
+            "val/loss": val_loss,
+            "lr": lr_scheduler.get_last_lr()[0]
+        }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
@@ -792,7 +807,7 @@ def main():
             accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
             min_val_loss = val_loss
 
-        if accelerator.is_main_process:
+        if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
                 global_step,
                 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
```
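The main behavioral addition in this commit is the `--sample_frequency` option: a per-epoch `sample_checkpoint` flag is armed whenever `global_step` crosses a multiple of the frequency, and samples are only rendered on epochs where that flag is set. The sketch below is illustrative only and not part of dreambooth.py; the `run_epoch` helper and its arguments are invented for the example, with the flag name and default taken from the diff above.

```python
# Illustrative sketch (assumptions noted above): how --sample_frequency
# gates per-epoch sampling, mirroring the loop structure in the diff.
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=100,
        help="How often to save a checkpoint and sample image",
    )
    return parser.parse_args(argv)


def run_epoch(args, global_step, steps_in_epoch):
    """Simulate one epoch; return (sample_checkpoint, global_step)."""
    sample_checkpoint = False
    for _ in range(steps_in_epoch):
        # ... forward/backward/optimizer step would happen here ...
        global_step += 1
        if global_step % args.sample_frequency == 0:
            sample_checkpoint = True
    return sample_checkpoint, global_step


if __name__ == "__main__":
    args = parse_args([])  # use the default sample_frequency of 100
    step = 0
    for epoch in range(3):
        sample_checkpoint, step = run_epoch(args, step, steps_in_epoch=60)
        # In the training script this is additionally guarded by
        # accelerator.is_main_process before checkpointer.save_samples().
        if sample_checkpoint:
            print(f"epoch {epoch}: would save samples at global step {step}")
```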
