author    Volpeon <git@volpeon.ink>  2022-10-08 21:56:54 +0200
committer Volpeon <git@volpeon.ink>  2022-10-08 21:56:54 +0200
commit    6aadb34af4fe5ca2dfc92fae8eee87610a5848ad
tree      f490b4794366e78f7b079eb04de1c7c00e17d34a /dreambooth.py
parent    Fix small details
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py | 75
1 file changed, 45 insertions(+), 30 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index a26bea7..7b61c45 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -3,6 +3,7 @@ import math
 import os
 import datetime
 import logging
+import json
 from pathlib import Path
 
 import numpy as np
@@ -21,7 +22,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-import json
 
 from data.csv import CSVDataModule
 
@@ -68,7 +68,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=4,
+        default=200,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -140,7 +140,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="constant",
+        default="linear",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
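For reference, the scheduler names in the help string are typically resolved through diffusers' get_scheduler helper; this hunk doesn't show the script's actual wiring, but a minimal, self-contained sketch of what the new "linear" default means (the optimizer and step counts below are illustrative, not taken from the script):

    import torch
    from diffusers.optimization import get_scheduler

    # Dummy optimizer just to make the example runnable.
    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-6)

    # "linear" ramps the learning rate linearly down to zero over training,
    # whereas the old default "constant" kept it fixed for the whole run.
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=1000,
    )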
@@ -199,6 +199,12 @@ def parse_args():
199 help="For distributed training: local_rank" 199 help="For distributed training: local_rank"
200 ) 200 )
201 parser.add_argument( 201 parser.add_argument(
202 "--sample_frequency",
203 type=int,
204 default=100,
205 help="How often to save a checkpoint and sample image",
206 )
207 parser.add_argument(
202 "--sample_image_size", 208 "--sample_image_size",
203 type=int, 209 type=int,
204 default=512, 210 default=512,
@@ -366,20 +372,20 @@ class Checkpointer:
             generator=generator,
         )
 
-        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
+        with torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
+                data_enum = enumerate(data)
 
-            for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                for i in range(self.sample_batches):
+                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                    prompt = [prompt.format(self.instance_identifier)
+                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
-                with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
@@ -393,15 +399,15 @@ class Checkpointer:
                         output_type='pil'
                     )["sample"]
 
-                all_samples += samples
+                    all_samples += samples
 
-                del samples
+                    del samples
 
-            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path)
 
-            del all_samples
-            del image_grid
+                del all_samples
+                del image_grid
 
         del unwrapped
         del scheduler
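A note on the restructuring above: the per-call accelerator.autocast() context is dropped in favor of a single torch.inference_mode() block around the whole sampling loop. A minimal sketch of the pattern (the pipeline argument is a stand-in that mimics the dict-with-"sample" return used above, not the script's actual class):

    import torch

    def render_samples(pipeline, prompt_batches):
        # Hypothetical helper mirroring Checkpointer.save_samples: one
        # inference_mode block disables autograd for every batch at once,
        # instead of re-entering a context manager per pipeline call.
        all_samples = []
        with torch.inference_mode():
            for prompts in prompt_batches:
                all_samples += pipeline(prompt=prompts)["sample"]
        return all_samples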
@@ -538,7 +544,7 @@ def main():
     datamodule.setup()
 
     if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data if not item[1].exists()]
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
             batched_data = [missing_data[i:i+args.sample_batch_size]
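The old comprehension indexed tuples (item[1] here, and p[1]/p[2]/p[3] in the generation loop below); the new one reads named attributes, which implies CSVDataModule now yields structured items. A rough sketch of the shape the new code assumes (field names inferred from this diff alone; the real definition lives in data/csv.py and may differ):

    from dataclasses import dataclass
    from pathlib import Path

    @dataclass
    class CSVDataItem:
        # Inferred from the attribute accesses in this commit.
        instance_image_path: Path  # assumed counterpart of the old item[0]
        class_image_path: Path     # old item[1] / p[1]
        prompt: str                # old p[2]
        nprompt: str               # old p[3]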
@@ -558,20 +564,20 @@ def main():
         pipeline.enable_attention_slicing()
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        for batch in batched_data:
-            image_name = [p[1] for p in batch]
-            prompt = [p[2].format(args.class_identifier) for p in batch]
-            nprompt = [p[3] for p in batch]
+        with torch.inference_mode():
+            for batch in batched_data:
+                image_name = [p.class_image_path for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                nprompt = [p.nprompt for p in batch]
 
-            with accelerator.autocast():
                 images = pipeline(
                     prompt=prompt,
                     negative_prompt=nprompt,
                     num_inference_steps=args.sample_steps
                 ).images
 
-            for i, image in enumerate(images):
-                image.save(image_name[i])
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
 
         del pipeline
 
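The batched_data list consumed above is built with plain slice-based chunking (the comprehension shown two hunks earlier). As a standalone illustration of that pattern:

    def chunk(items, size):
        # Same slicing pattern as batched_data = [missing_data[i:i+args.sample_batch_size] ...]:
        # consecutive chunks of at most `size` items; the last chunk may be shorter.
        return [items[i:i + size] for i in range(0, len(items), size)]

    assert chunk(list(range(7)), 3) == [[0, 1, 2], [3, 4, 5], [6]]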
@@ -677,6 +683,8 @@ def main():
         unet.train()
         train_loss = 0.0
 
+        sample_checkpoint = False
+
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
@@ -737,6 +745,9 @@ def main():
 
                 global_step += 1
 
+                if global_step % args.sample_frequency == 0:
+                    sample_checkpoint = True
+
             logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
             local_progress_bar.set_postfix(**logs)
 
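Together with the sample_checkpoint = False reset at the top of the epoch, this implements a simple step-frequency gate: the end-of-epoch sampling (last hunk below) only fires if a multiple of --sample_frequency was crossed during that epoch. A standalone sketch of the pattern (step counts illustrative; in the script, global_step persists across epochs):

    sample_frequency = 100     # args.sample_frequency
    sample_checkpoint = False  # reset at the start of each epoch

    for global_step in range(1, 251):  # pretend one epoch spans steps 1..250
        if global_step % sample_frequency == 0:
            sample_checkpoint = True

    # Steps 100 and 200 were crossed, so samples would be rendered after
    # this epoch; an epoch that crosses no multiple skips sampling.
    assert sample_checkpoint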
@@ -783,7 +794,11 @@ def main():
 
         val_loss /= len(val_dataloader)
 
-        accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
+        accelerator.log({
+            "train/loss": train_loss,
+            "val/loss": val_loss,
+            "lr": lr_scheduler.get_last_lr()[0]
+        }, step=global_step)
 
         local_progress_bar.clear()
         global_progress_bar.clear()
@@ -792,7 +807,7 @@ def main():
             accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
             min_val_loss = val_loss
 
-        if accelerator.is_main_process:
+        if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(
                 global_step,
                 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)