author    Volpeon <git@volpeon.ink>  2023-01-07 17:10:06 +0100
committer Volpeon <git@volpeon.ink>  2023-01-07 17:10:06 +0100
commit    3353ffb64c280a938a0f2513d13b716c1fca8c02
tree      dbdc2ae1ddc5dc7758a2210e14e1fc9b18df7697 /train_dreambooth.py
parent    Made aspect ratio bucketing configurable
Cleanup
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  265
1 file changed, 131 insertions(+), 134 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d265bcc..589af59 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -5,6 +5,7 @@ import datetime
 import logging
 from pathlib import Path
 from functools import partial
+from contextlib import contextmanager, nullcontext

 import torch
 import torch.utils.checkpoint
@@ -23,7 +24,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import run_model
+from training.common import run_model, generate_class_images
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -216,7 +217,6 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
@@ -273,7 +273,6 @@ def parse_args():
     parser.add_argument(
         "--use_ema",
         action="store_true",
-        default=True,
         help="Whether to use EMA model."
     )
     parser.add_argument(
@@ -294,7 +293,6 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
-        default=True,
         help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
     parser.add_argument(
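Note on the three argparse hunks above: a flag declared with action="store_true" can only ever set its value to True, so combining it with default=True made --scale_lr, --use_ema, and --use_8bit_adam impossible to disable from the command line. Removing the default restores normal opt-in behaviour. A minimal sketch of the difference:

    import argparse

    # Old declaration: the value is True whether or not the flag is given.
    p = argparse.ArgumentParser()
    p.add_argument("--use_ema", action="store_true", default=True)
    assert p.parse_args([]).use_ema is True
    assert p.parse_args(["--use_ema"]).use_ema is True

    # New declaration: defaults to False, passing the flag enables it.
    p = argparse.ArgumentParser()
    p.add_argument("--use_ema", action="store_true")
    assert p.parse_args([]).use_ema is False
    assert p.parse_args(["--use_ema"]).use_ema is True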
@@ -486,17 +484,20 @@ class Checkpointer(CheckpointerBase):
     def save_model(self):
         print("Saving model...")

-        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)

-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        )
-        pipeline.save_pretrained(self.output_dir.joinpath("model"))
+        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
+
+        with ema_context:
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            )
+            pipeline.save_pretrained(self.output_dir.joinpath("model"))

         del unet
         del text_encoder
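This hunk (and save_samples below) replaces the separate ema_unet.averaged_model copy with ema_unet.apply_temporary(unet.parameters()), falling back to contextlib.nullcontext() when EMA is disabled. apply_temporary is presumably a context manager on this repo's EMAModel that swaps the averaged weights into the live parameters on entry and restores the training weights on exit, so the pipeline is saved with EMA weights without keeping a second full UNet in memory. A minimal sketch of the assumed semantics (names here are illustrative, not the repo's actual implementation):

    import torch
    from contextlib import contextmanager

    class EMAModelSketch:
        def __init__(self, parameters):
            # Shadow copies, updated with an exponential moving average during training.
            self.shadow_params = [p.detach().clone() for p in parameters]

        @contextmanager
        def apply_temporary(self, parameters):
            parameters = list(parameters)
            backup = [p.detach().clone() for p in parameters]
            try:
                # Swap the averaged weights in for saving or sampling...
                for shadow, param in zip(self.shadow_params, parameters):
                    param.data.copy_(shadow.data)
                yield
            finally:
                # ...and always restore the training weights afterwards.
                for orig, param in zip(backup, parameters):
                    param.data.copy_(orig.data)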
@@ -507,28 +508,31 @@ class Checkpointer(CheckpointerBase):

     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)

-        orig_unet_dtype = unet.dtype
-        orig_text_encoder_dtype = text_encoder.dtype
-
-        unet.to(dtype=self.weight_dtype)
-        text_encoder.to(dtype=self.weight_dtype)
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
-
-        unet.to(dtype=orig_unet_dtype)
-        text_encoder.to(dtype=orig_text_encoder_dtype)
+        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
+
+        with ema_context:
+            orig_unet_dtype = unet.dtype
+            orig_text_encoder_dtype = text_encoder.dtype
+
+            unet.to(dtype=self.weight_dtype)
+            text_encoder.to(dtype=self.weight_dtype)
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+
+            unet.to(dtype=orig_unet_dtype)
+            text_encoder.to(dtype=orig_text_encoder_dtype)

         del unet
         del text_encoder
@@ -580,6 +584,7 @@ def main():
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
+    ema_unet = None

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -589,14 +594,12 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()

-    ema_unet = None
     if args.use_ema:
         ema_unet = EMAModel(
-            unet,
+            unet.parameters(),
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
            max_value=args.ema_max_decay,
-            device=accelerator.device
         )

     embeddings = patch_managed_embeddings(text_encoder)
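EMAModel now tracks a plain parameter iterable instead of wrapping the whole module, and the explicit device argument is gone; the shadow parameters are moved later with ema_unet.to(accelerator.device) in a hunk below. The inv_gamma/power/max_value arguments describe the warmup schedule that diffusers-style EMA implementations commonly use, roughly as follows (exact clamping may differ in this repo's EMAModel):

    def ema_decay(step: int, inv_gamma: float, power: float, max_value: float) -> float:
        # Decay ramps from 0 toward max_value as optimization steps accumulate.
        value = 1 - (1 + step / inv_gamma) ** -power
        return max(0.0, min(value, max_value))

    # e.g. with inv_gamma=1, power=0.75: step 10 -> ~0.83, step 1000 -> ~0.994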
@@ -748,52 +751,27 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()

-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.cprompt for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=args.sample_image_size,
-                        width=args.sample_image_size,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-    train_dataloader = datamodule.train_dataloader()
-    val_dataloader = datamodule.val_dataloader()
+    train_dataloaders = datamodule.train_dataloaders
+    val_dataloader = datamodule.val_dataloader
+
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
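The inline class-image generation has moved into training.common.generate_class_images. Judging from the deleted block above, the helper presumably batches the training items whose class_image_path does not yet exist, renders them with a temporary VlpnStableDiffusion pipeline, and saves each image. A sketch reconstructed from the removed code, assuming the same imports as train_dreambooth.py (the actual implementation lives in training/common.py and may differ):

    def generate_class_images_sketch(accelerator, text_encoder, vae, unet, tokenizer,
                                     scheduler, data_train, sample_batch_size,
                                     sample_image_size, sample_steps):
        missing = [item for item in data_train if not item.class_image_path.exists()]
        if len(missing) == 0:
            return

        batches = [
            missing[i:i + sample_batch_size]
            for i in range(0, len(missing), sample_batch_size)
        ]

        pipeline = VlpnStableDiffusion(
            text_encoder=text_encoder, vae=vae, unet=unet,
            tokenizer=tokenizer, scheduler=scheduler,
        ).to(accelerator.device)
        pipeline.set_progress_bar_config(dynamic_ncols=True)

        with torch.inference_mode():
            for batch in batches:
                images = pipeline(
                    prompt=[item.cprompt for item in batch],
                    negative_prompt=[item.nprompt for item in batch],
                    height=sample_image_size,
                    width=sample_image_size,
                    num_inference_steps=sample_steps,
                ).images
                for image, item in zip(images, batch):
                    image.save(item.class_image_path)

        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()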
@@ -838,8 +816,12 @@ def main():
     # Keep text_encoder and vae in eval mode as we don't train these
     vae.eval()

+    if args.use_ema:
+        ema_unet.to(accelerator.device)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

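With aspect ratio bucketing, the datamodule exposes one train dataloader per bucket (train_dataloaders), so an epoch's update count is now the total number of batches across all of them divided by the gradient accumulation steps, rounded up. For example, with hypothetical bucket sizes:

    import math

    dataloader_lengths = [120, 80, 50]        # batches per bucket dataloader (made up)
    gradient_accumulation_steps = 4

    total_batches = sum(dataloader_lengths)   # 250
    num_update_steps_per_epoch = math.ceil(total_batches / gradient_accumulation_steps)
    print(num_update_steps_per_epoch)         # ceil(250 / 4) = 63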
@@ -847,11 +829,25 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs

+    @contextmanager
     def on_train():
-        tokenizer.train()
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            pass

+    @contextmanager
     def on_eval():
-        tokenizer.eval()
+        try:
+            tokenizer.eval()
+
+            ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass

     loop = partial(
         run_model,
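on_train and on_eval are now context managers rather than plain callables, so the loops below can wrap whole phases in with blocks: on_eval switches the tokenizer to eval mode and, when EMA is enabled, holds the averaged weights in place for the entire validation pass, restoring the training weights when the block exits. The try/finally with a bare pass leaves an explicit slot for future teardown. A minimal sketch of the pattern, with stand-in names:

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def apply_ema_weights():              # stand-in for ema_unet.apply_temporary(...)
        print("EMA weights in")
        try:
            yield
        finally:
            print("EMA weights out")

    @contextmanager
    def on_eval_sketch(use_ema: bool):
        try:
            ema_context = apply_ema_weights() if use_ema else nullcontext()
            with ema_context:
                yield                     # the validation loop body runs here
        finally:
            pass                          # room for symmetric teardown later

    with on_eval_sketch(use_ema=True):
        print("validating")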
@@ -881,7 +877,7 @@ def main():
             accelerator,
             text_encoder,
             optimizer,
-            train_dataloader,
+            train_dataloaders[0],
             val_dataloader,
             loop,
             on_train=tokenizer.train,
@@ -962,88 +958,89 @@ def main():
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
             text_encoder.requires_grad_(False)
-        on_train()

-        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
-                loss, acc, bsz = loop(step, batch)
+        with on_train():
+            for train_dataloader in train_dataloaders:
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(unet):
+                        loss, acc, bsz = loop(step, batch)

                         accelerator.backward(loss)

                         if accelerator.sync_gradients:
                             params_to_clip = (
                                 itertools.chain(unet.parameters(), text_encoder.parameters())
                                 if args.train_text_encoder and epoch < args.train_text_encoder_epochs
                                 else unet.parameters()
                             )
                             accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                         optimizer.step()
                         if not accelerator.optimizer_step_was_skipped:
                             lr_scheduler.step()
                         if args.use_ema:
-                            ema_unet.step(unet)
+                            ema_unet.step(unet.parameters())
                         optimizer.zero_grad(set_to_none=True)

                         avg_loss.update(loss.detach_(), bsz)
                         avg_acc.update(acc.detach_(), bsz)

                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)

                         global_step += 1

                         logs = {
                             "train/loss": avg_loss.avg.item(),
                             "train/acc": avg_acc.avg.item(),
                             "train/cur_loss": loss.item(),
                             "train/cur_acc": acc.item(),
                             "lr": lr_scheduler.get_last_lr()[0]
                         }
                         if args.use_ema:
                             logs["ema_decay"] = 1 - ema_unet.decay

                         accelerator.log(logs, step=global_step)

                         local_progress_bar.set_postfix(**logs)

                     if global_step >= args.max_train_steps:
                         break

         accelerator.wait_for_everyone()

         unet.eval()
         text_encoder.eval()
-        on_eval()

         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()

         with torch.inference_mode():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(step, batch, True)
+            with on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loop(step, batch, True)

                     loss = loss.detach_()
                     acc = acc.detach_()

                     cur_loss_val.update(loss, bsz)
                     cur_acc_val.update(acc, bsz)

                     avg_loss_val.update(loss, bsz)
                     avg_acc_val.update(acc, bsz)

                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

                     logs = {
                         "val/loss": avg_loss_val.avg.item(),
                         "val/acc": avg_acc_val.avg.item(),
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)

         logs["val/cur_loss"] = cur_loss_val.avg.item()
         logs["val/cur_acc"] = cur_acc_val.avg.item()