-rw-r--r--  train_dreambooth.py  | 265
-rw-r--r--  train_ti.py          |  60
-rw-r--r--  training/common.py   |  54
3 files changed, 201 insertions(+), 178 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d265bcc..589af59 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -5,6 +5,7 @@ import datetime
 import logging
 from pathlib import Path
 from functools import partial
+from contextlib import contextmanager, nullcontext

 import torch
 import torch.utils.checkpoint
@@ -23,7 +24,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import run_model
+from training.common import run_model, generate_class_images
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -216,7 +217,6 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
@@ -273,7 +273,6 @@ def parse_args():
     parser.add_argument(
         "--use_ema",
         action="store_true",
-        default=True,
         help="Whether to use EMA model."
     )
     parser.add_argument(
@@ -294,7 +293,6 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
-        default=True,
         help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
     parser.add_argument(
@@ -486,17 +484,20 @@ class Checkpointer(CheckpointerBase):
     def save_model(self):
         print("Saving model...")

-        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)

-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        )
-        pipeline.save_pretrained(self.output_dir.joinpath("model"))
+        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()
+
+        with ema_context:
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            )
+            pipeline.save_pretrained(self.output_dir.joinpath("model"))

         del unet
         del text_encoder
@@ -507,28 +508,31 @@ class Checkpointer(CheckpointerBase):

     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)

-        orig_unet_dtype = unet.dtype
-        orig_text_encoder_dtype = text_encoder.dtype
+        ema_context = self.ema_unet.apply_temporary(unet.parameters()) if self.ema_unet is not None else nullcontext()

-        unet.to(dtype=self.weight_dtype)
-        text_encoder.to(dtype=self.weight_dtype)
+        with ema_context:
+            orig_unet_dtype = unet.dtype
+            orig_text_encoder_dtype = text_encoder.dtype

-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
+            unet.to(dtype=self.weight_dtype)
+            text_encoder.to(dtype=self.weight_dtype)
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)

-        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)

-        unet.to(dtype=orig_unet_dtype)
-        text_encoder.to(dtype=orig_text_encoder_dtype)
+            unet.to(dtype=orig_unet_dtype)
+            text_encoder.to(dtype=orig_text_encoder_dtype)

         del unet
         del text_encoder
@@ -580,6 +584,7 @@ def main():
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
+    ema_unet = None

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -589,14 +594,12 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()

-    ema_unet = None
     if args.use_ema:
         ema_unet = EMAModel(
-            unet,
+            unet.parameters(),
             inv_gamma=args.ema_inv_gamma,
             power=args.ema_power,
             max_value=args.ema_max_decay,
-            device=accelerator.device
         )

     embeddings = patch_managed_embeddings(text_encoder)
@@ -748,52 +751,27 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()

-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.cprompt for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=args.sample_image_size,
-                        width=args.sample_image_size,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-    train_dataloader = datamodule.train_dataloader()
-    val_dataloader = datamodule.val_dataloader()
+    train_dataloaders = datamodule.train_dataloaders
+    val_dataloader = datamodule.val_dataloader
+
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -838,8 +816,12 @@ def main():
     # Keep text_encoder and vae in eval mode as we don't train these
     vae.eval()

+    if args.use_ema:
+        ema_unet.to(accelerator.device)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

@@ -847,11 +829,25 @@ def main():
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     val_steps = num_val_steps_per_epoch * num_epochs

+    @contextmanager
     def on_train():
-        tokenizer.train()
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            pass

+    @contextmanager
     def on_eval():
-        tokenizer.eval()
+        try:
+            tokenizer.eval()
+
+            ema_context = ema_unet.apply_temporary(unet.parameters()) if args.use_ema else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass

     loop = partial(
         run_model,
@@ -881,7 +877,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloader,
+        train_dataloaders[0],
         val_dataloader,
         loop,
         on_train=tokenizer.train,
@@ -962,88 +958,89 @@ def main():
                 text_encoder.train()
             elif epoch == args.train_text_encoder_epochs:
                 text_encoder.requires_grad_(False)
-            on_train()

-        for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(unet):
-                loss, acc, bsz = loop(step, batch)
+        with on_train():
+            for train_dataloader in train_dataloaders:
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(unet):
+                        loss, acc, bsz = loop(step, batch)

-                accelerator.backward(loss)
+                        accelerator.backward(loss)

-                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                        else unet.parameters()
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                        if accelerator.sync_gradients:
+                            params_to_clip = (
+                                itertools.chain(unet.parameters(), text_encoder.parameters())
+                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                                else unet.parameters()
+                            )
+                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

-                optimizer.step()
-                if not accelerator.optimizer_step_was_skipped:
-                    lr_scheduler.step()
-                if args.use_ema:
-                    ema_unet.step(unet)
-                optimizer.zero_grad(set_to_none=True)
+                        optimizer.step()
+                        if not accelerator.optimizer_step_was_skipped:
+                            lr_scheduler.step()
+                        if args.use_ema:
+                            ema_unet.step(unet.parameters())
+                        optimizer.zero_grad(set_to_none=True)

-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)

-            # Checks if the accelerator has performed an optimization step behind the scenes
-            if accelerator.sync_gradients:
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)

-                global_step += 1
+                        global_step += 1

-                logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
-                    "train/cur_loss": loss.item(),
-                    "train/cur_acc": acc.item(),
-                    "lr": lr_scheduler.get_last_lr()[0]
-                }
-                if args.use_ema:
-                    logs["ema_decay"] = 1 - ema_unet.decay
+                        logs = {
+                            "train/loss": avg_loss.avg.item(),
+                            "train/acc": avg_acc.avg.item(),
+                            "train/cur_loss": loss.item(),
+                            "train/cur_acc": acc.item(),
+                            "lr": lr_scheduler.get_last_lr()[0]
+                        }
+                        if args.use_ema:
+                            logs["ema_decay"] = 1 - ema_unet.decay

-                accelerator.log(logs, step=global_step)
+                        accelerator.log(logs, step=global_step)

-            local_progress_bar.set_postfix(**logs)
+                    local_progress_bar.set_postfix(**logs)

-            if global_step >= args.max_train_steps:
-                break
+                    if global_step >= args.max_train_steps:
+                        break

         accelerator.wait_for_everyone()

         unet.eval()
         text_encoder.eval()
-        on_eval()

         cur_loss_val = AverageMeter()
         cur_acc_val = AverageMeter()

         with torch.inference_mode():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(step, batch, True)
+            with on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loop(step, batch, True)

-                loss = loss.detach_()
-                acc = acc.detach_()
+                    loss = loss.detach_()
+                    acc = acc.detach_()

-                cur_loss_val.update(loss, bsz)
-                cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)

-                avg_loss_val.update(loss, bsz)
-                avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)

-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)

-                logs = {
-                    "val/loss": avg_loss_val.avg.item(),
-                    "val/acc": avg_acc_val.avg.item(),
-                    "val/cur_loss": loss.item(),
-                    "val/cur_acc": acc.item(),
-                }
-                local_progress_bar.set_postfix(**logs)
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
+                    local_progress_bar.set_postfix(**logs)

         logs["val/cur_loss"] = cur_loss_val.avg.item()
         logs["val/cur_acc"] = cur_acc_val.avg.item()
diff --git a/train_ti.py b/train_ti.py
index 38c9755..b4b602b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -22,7 +22,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import run_model
+from training.common import run_model, generate_class_images
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
@@ -219,7 +219,6 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
@@ -734,50 +733,23 @@ def main():
     )
     datamodule.setup()

-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.cprompt for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=args.sample_image_size,
-                        width=args.sample_image_size,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     train_dataloaders = datamodule.train_dataloaders
-    default_train_dataloader = train_dataloaders[0]
     val_dataloader = datamodule.val_dataloader

+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
@@ -898,7 +870,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        default_train_dataloader,
+        train_dataloaders[0],
         val_dataloader,
         loop,
         on_train=on_train,
diff --git a/training/common.py b/training/common.py
index ab2741a..67c2ab6 100644
--- a/training/common.py
+++ b/training/common.py
@@ -3,6 +3,60 @@ import torch.nn.functional as F

 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+
+
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+
+    if len(missing_data) != 0:
+        batched_data = [
+            missing_data[i:i+sample_batch_size]
+            for i in range(0, len(missing_data), sample_batch_size)
+        ]
+
+        pipeline = VlpnStableDiffusion(
+            text_encoder=text_encoder,
+            vae=vae,
+            unet=unet,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+        ).to(accelerator.device)
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+        with torch.inference_mode():
+            for batch in batched_data:
+                image_name = [item.class_image_path for item in batch]
+                prompt = [item.cprompt for item in batch]
+                nprompt = [item.nprompt for item in batch]
+
+                images = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=sample_image_size,
+                    width=sample_image_size,
+                    num_inference_steps=sample_steps
+                ).images
+
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
+
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+

 def run_model(
     vae: AutoencoderKL,
