-rw-r--r--   train_dreambooth.py | 265
-rw-r--r--   train_ti.py         |  60
-rw-r--r--   training/common.py  |  54
3 files changed, 201 insertions, 178 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d265bcc..589af59 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -5,6 +5,7 @@ import datetime
 import logging
 from pathlib import Path
 from functools import partial
+from contextlib import contextmanager, nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -23,7 +24,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import run_model
+from training.common import run_model, generate_class_images
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -216,7 +217,6 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
@@ -273,7 +273,6 @@ def parse_args():
     parser.add_argument(
         "--use_ema",
         action="store_true",
-        default=True,
         help="Whether to use EMA model."
     )
     parser.add_argument(
@@ -294,7 +293,6 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
-        default=True,
         help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
     parser.add_argument(
@@ -486,17 +484,20 @@ class Checkpointer(CheckpointerBase):
     def save_model(self):
         print("Saving model...")
 
-        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        )
-        pipeline.save_pretrained(self.output_dir.joinpath("model"))
+        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
+
+        with ema_context:
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            )
+            pipeline.save_pretrained(self.output_dir.joinpath("model"))
 
         del unet
         del text_encoder
@@ -507,28 +508,31 @@ class Checkpointer(CheckpointerBase):
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        orig_unet_dtype = unet.dtype
-        orig_text_encoder_dtype = text_encoder.dtype
+        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
 
-        unet.to(dtype=self.weight_dtype)
-        text_encoder.to(dtype=self.weight_dtype)
+        with ema_context:
+            orig_unet_dtype = unet.dtype
+            orig_text_encoder_dtype = text_encoder.dtype
 
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
+            unet.to(dtype=self.weight_dtype)
+            text_encoder.to(dtype=self.weight_dtype)
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
         unet.to(dtype=orig_unet_dtype)
         text_encoder.to(dtype=orig_text_encoder_dtype)
 
         del unet
         del text_encoder
@@ -580,6 +584,7 @@ def main():
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
+    ema_unet = None
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -589,14 +594,12 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
-    ema_unet = None
     if args.use_ema:
         ema_unet = EMAModel(
-            unet,
+            unet.parameters(),
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
            max_value=args.ema_max_decay,
-            device=accelerator.device
         )
 
     embeddings = patch_managed_embeddings(text_encoder)
@@ -748,52 +751,27 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.cprompt for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=args.sample_image_size,
-                        width=args.sample_image_size,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
+    train_dataloaders = datamodule.train_dataloaders
+    val_dataloader = datamodule.val_dataloader
 
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-    train_dataloader = datamodule.train_dataloader()
-    val_dataloader = datamodule.val_dataloader()
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -838,8 +816,12 @@ def main():
     # Keep text_encoder and vae in eval mode as we don't train these
     vae.eval()
 
+    if args.use_ema:
+        ema_unet.to(accelerator.device)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -847,11 +829,25 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs
 
+    @contextmanager
     def on_train():
-        tokenizer.train()
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            pass
 
+    @contextmanager
     def on_eval():
-        tokenizer.eval()
+        try:
+            tokenizer.eval()
+
+            ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
 
     loop = partial(
         run_model,
@@ -881,7 +877,7 @@ def main():
             accelerator,
             text_encoder,
             optimizer,
-            train_dataloader,
+            train_dataloaders[0],
             val_dataloader,
             loop,
             on_train=tokenizer.train,
@@ -962,88 +958,89 @@ def main():
             text_encoder.train()
         elif epoch == args.train_text_encoder_epochs:
             text_encoder.requires_grad_(False)
-        on_train()
 
-        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
-                loss, acc, bsz = loop(step, batch)
+        with on_train():
+            for train_dataloader in train_dataloaders:
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(unet):
+                        loss, acc, bsz = loop(step, batch)
 
                         accelerator.backward(loss)
 
                         if accelerator.sync_gradients:
                             params_to_clip = (
                                 itertools.chain(unet.parameters(), text_encoder.parameters())
                                 if args.train_text_encoder and epoch < args.train_text_encoder_epochs
                                 else unet.parameters()
                             )
                             accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                         optimizer.step()
                         if not accelerator.optimizer_step_was_skipped:
                             lr_scheduler.step()
                         if args.use_ema:
-                            ema_unet.step(unet)
+                            ema_unet.step(unet.parameters())
                         optimizer.zero_grad(set_to_none=True)
 
                         avg_loss.update(loss.detach_(), bsz)
                         avg_acc.update(acc.detach_(), bsz)
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
 
                         global_step += 1
 
                         logs = {
                             "train/loss": avg_loss.avg.item(),
                             "train/acc": avg_acc.avg.item(),
                             "train/cur_loss": loss.item(),
                             "train/cur_acc": acc.item(),
                             "lr": lr_scheduler.get_last_lr()[0]
                         }
                         if args.use_ema:
                             logs["ema_decay"] = 1 - ema_unet.decay
 
                         accelerator.log(logs, step=global_step)
 
                         local_progress_bar.set_postfix(**logs)
 
                     if global_step >= args.max_train_steps:
                         break
 
         accelerator.wait_for_everyone()
 
         unet.eval()
         text_encoder.eval()
-        on_eval()
 
         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()
 
         with torch.inference_mode():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(step, batch, True)
+            with on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loop(step, batch, True)
 
                     loss = loss.detach_()
                     acc = acc.detach_()
 
                     cur_loss_val.update(loss, bsz)
                     cur_acc_val.update(acc, bsz)
 
                     avg_loss_val.update(loss, bsz)
                     avg_acc_val.update(acc, bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                     logs = {
                         "val/loss": avg_loss_val.avg.item(),
                         "val/acc": avg_acc_val.avg.item(),
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
         logs["val/cur_loss"] = cur_loss_val.avg.item()
         logs["val/cur_acc"] = cur_acc_val.avg.item()
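Both Checkpointer paths above now build their pipeline under ema_unet.apply_temporary(unet.parameters()), a context manager on the repository's EMAModel in training/util.py whose implementation is not part of this diff. A minimal sketch of what such a context manager could look like, assuming the EMA model exposes a copy_to(parameters) method; the standalone helper and its names here are illustrative, not the repository's actual code:

import torch
from contextlib import contextmanager

@contextmanager
def apply_temporary(ema_model, parameters):
    # Back up the live training weights, swap in the EMA-averaged weights for the
    # duration of the block, then restore the originals on exit.
    parameters = list(parameters)
    backup = [p.detach().clone() for p in parameters]
    ema_model.copy_to(parameters)
    try:
        yield
    finally:
        with torch.no_grad():
            for param, saved in zip(parameters, backup):
                param.copy_(saved)

The idea is that code inside the with block sees the averaged weights, and the live training weights come back automatically once the block exits, so checkpointing never permanently overwrites the optimizer's state.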
diff --git a/train_ti.py b/train_ti.py
index 38c9755..b4b602b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -22,7 +22,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import run_model
+from training.common import run_model, generate_class_images
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
@@ -219,7 +219,6 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
@@ -734,50 +733,23 @@ def main():
     )
     datamodule.setup()
 
-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.cprompt for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=args.sample_image_size,
-                        width=args.sample_image_size,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     train_dataloaders = datamodule.train_dataloaders
-    default_train_dataloader = train_dataloaders[0]
     val_dataloader = datamodule.val_dataloader
 
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
@@ -898,7 +870,7 @@ def main():
             accelerator,
             text_encoder,
             optimizer,
-            default_train_dataloader,
+            train_dataloaders[0],
             val_dataloader,
             loop,
             on_train=on_train,
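For context on the removed default=True lines in parse_args (in both training scripts): with action="store_true", argparse already defaults the option to False, so forcing the default to True made the flag effectively always on, with no way to disable it from the command line. A small standalone illustration, not part of the repository:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_ema", action="store_true", default=True)

print(parser.parse_args([]).use_ema)             # True, even though the flag was never passed
print(parser.parse_args(["--use_ema"]).use_ema)  # True; there is no way to get False

Dropping the explicit default restores the usual behaviour: the option is False unless the flag is given.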
diff --git a/training/common.py b/training/common.py
index ab2741a..67c2ab6 100644
--- a/training/common.py
+++ b/training/common.py
@@ -3,6 +3,60 @@ import torch.nn.functional as F
 
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+
+
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+
+    if len(missing_data) != 0:
+        batched_data = [
+            missing_data[i:i+sample_batch_size]
+            for i in range(0, len(missing_data), sample_batch_size)
+        ]
+
+        pipeline = VlpnStableDiffusion(
+            text_encoder=text_encoder,
+            vae=vae,
+            unet=unet,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+        ).to(accelerator.device)
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+        with torch.inference_mode():
+            for batch in batched_data:
+                image_name = [item.class_image_path for item in batch]
+                prompt = [item.cprompt for item in batch]
+                nprompt = [item.nprompt for item in batch]
+
+                images = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=sample_image_size,
+                    width=sample_image_size,
+                    num_inference_steps=sample_steps
+                ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
 
 def run_model(
     vae: AutoencoderKL,
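With the move to multiple train dataloaders, the per-epoch step math in both scripts now sums the batches across all dataloaders before dividing by the accumulation factor. A quick sketch of that arithmetic with made-up sizes (the dataloader lengths and accumulation setting below are hypothetical, not taken from the repository):

import math

dataloader_lengths = [30, 10]           # e.g. instance data and class data batches
gradient_accumulation_steps = 4

num_update_steps_per_dataloader = sum(dataloader_lengths)   # 40 batches per epoch in total
num_update_steps_per_epoch = math.ceil(
    num_update_steps_per_dataloader / gradient_accumulation_steps
)
print(num_update_steps_per_epoch)       # 10 optimizer updates per epoch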