diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 60 |
1 file changed, 16 insertions, 44 deletions
diff --git a/train_ti.py b/train_ti.py index 38c9755..b4b602b 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -22,7 +22,7 @@ from slugify import slugify | |||
22 | from util import load_config, load_embeddings_from_dir | 22 | from util import load_config, load_embeddings_from_dir |
23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
24 | from data.csv import VlpnDataModule, VlpnDataItem | 24 | from data.csv import VlpnDataModule, VlpnDataItem |
25 | from training.common import run_model | 25 | from training.common import run_model, generate_class_images |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
28 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args |
@@ -219,7 +219,6 @@ def parse_args(): | |||
219 | parser.add_argument( | 219 | parser.add_argument( |
220 | "--scale_lr", | 220 | "--scale_lr", |
221 | action="store_true", | 221 | action="store_true", |
222 | default=True, | ||
223 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | 222 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", |
224 | ) | 223 | ) |
225 | parser.add_argument( | 224 | parser.add_argument( |
@@ -734,50 +733,23 @@ def main(): | |||
734 | ) | 733 | ) |
735 | datamodule.setup() | 734 | datamodule.setup() |
736 | 735 | ||
737 | if args.num_class_images != 0: | ||
738 | missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()] | ||
739 | |||
740 | if len(missing_data) != 0: | ||
741 | batched_data = [ | ||
742 | missing_data[i:i+args.sample_batch_size] | ||
743 | for i in range(0, len(missing_data), args.sample_batch_size) | ||
744 | ] | ||
745 | |||
746 | pipeline = VlpnStableDiffusion( | ||
747 | text_encoder=text_encoder, | ||
748 | vae=vae, | ||
749 | unet=unet, | ||
750 | tokenizer=tokenizer, | ||
751 | scheduler=checkpoint_scheduler, | ||
752 | ).to(accelerator.device) | ||
753 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
754 | |||
755 | with torch.inference_mode(): | ||
756 | for batch in batched_data: | ||
757 | image_name = [item.class_image_path for item in batch] | ||
758 | prompt = [item.cprompt for item in batch] | ||
759 | nprompt = [item.nprompt for item in batch] | ||
760 | |||
761 | images = pipeline( | ||
762 | prompt=prompt, | ||
763 | negative_prompt=nprompt, | ||
764 | height=args.sample_image_size, | ||
765 | width=args.sample_image_size, | ||
766 | num_inference_steps=args.sample_steps | ||
767 | ).images | ||
768 | |||
769 | for i, image in enumerate(images): | ||
770 | image.save(image_name[i]) | ||
771 | |||
772 | del pipeline | ||
773 | |||
774 | if torch.cuda.is_available(): | ||
775 | torch.cuda.empty_cache() | ||
776 | |||
777 | train_dataloaders = datamodule.train_dataloaders | 736 | train_dataloaders = datamodule.train_dataloaders |
778 | default_train_dataloader = train_dataloaders[0] | ||
779 | val_dataloader = datamodule.val_dataloader | 737 | val_dataloader = datamodule.val_dataloader |
780 | 738 | ||
739 | if args.num_class_images != 0: | ||
740 | generate_class_images( | ||
741 | accelerator, | ||
742 | text_encoder, | ||
743 | vae, | ||
744 | unet, | ||
745 | tokenizer, | ||
746 | checkpoint_scheduler, | ||
747 | datamodule.data_train, | ||
748 | args.sample_batch_size, | ||
749 | args.sample_image_size, | ||
750 | args.sample_steps | ||
751 | ) | ||
752 | |||
781 | # Scheduler and math around the number of training steps. | 753 | # Scheduler and math around the number of training steps. |
782 | overrode_max_train_steps = False | 754 | overrode_max_train_steps = False |
783 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) | 755 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) |
@@ -898,7 +870,7 @@ def main(): | |||
898 | accelerator, | 870 | accelerator, |
899 | text_encoder, | 871 | text_encoder, |
900 | optimizer, | 872 | optimizer, |
901 | default_train_dataloader, | 873 | train_dataloaders[0], |
902 | val_dataloader, | 874 | val_dataloader, |
903 | loop, | 875 | loop, |
904 | on_train=on_train, | 876 | on_train=on_train, |