| author | Volpeon <git@volpeon.ink> | 2023-01-07 17:10:06 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-07 17:10:06 +0100 |
| commit | 3353ffb64c280a938a0f2513d13b716c1fca8c02 | (patch) |
| tree | dbdc2ae1ddc5dc7758a2210e14e1fc9b18df7697 | /train_ti.py |
| parent | Made aspect ratio bucketing configurable | (diff) |
Cleanup
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 60 |
|---|---|---|

1 file changed, 16 insertions(+), 44 deletions(-)
```diff
diff --git a/train_ti.py b/train_ti.py
index 38c9755..b4b602b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -22,7 +22,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import run_model
+from training.common import run_model, generate_class_images
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
@@ -219,7 +219,6 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
```
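One behavioral note on the second hunk: combining `action="store_true"` with an explicit `default=True` makes a flag effectively inert, since the parsed value is True whether or not the option is given. Dropping the default restores the normal opt-in semantics. A minimal standalone illustration (not code from this repository):

```python
import argparse

parser = argparse.ArgumentParser()
# With default=True, store_true can never yield False: the parser starts
# at True and passing the flag only sets it to True again.
parser.add_argument("--always_on", action="store_true", default=True)
# Without an explicit default, store_true defaults to False until the
# flag is actually passed on the command line.
parser.add_argument("--scale_lr", action="store_true")

print(parser.parse_args([]))              # Namespace(always_on=True, scale_lr=False)
print(parser.parse_args(["--scale_lr"]))  # Namespace(always_on=True, scale_lr=True)
```

The bulk of the cleanup replaces the inline class-image generation loop in `main()` with a call to the new `generate_class_images` helper imported above: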
```diff
@@ -734,50 +733,23 @@ def main():
     )
     datamodule.setup()
 
-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.cprompt for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=args.sample_image_size,
-                        width=args.sample_image_size,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     train_dataloaders = datamodule.train_dataloaders
-    default_train_dataloader = train_dataloaders[0]
     val_dataloader = datamodule.val_dataloader
 
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
```
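For context, a minimal sketch of what the extracted `generate_class_images` helper in `training/common.py` plausibly looks like, reconstructed from the inline code removed above. The parameter names mirror the call site and are assumptions; the actual implementation in the repository may differ in detail:

```python
import torch

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion


# Sketch reconstructed from the removed inline code; parameter names are
# assumptions based on the call site in train_ti.py.
def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    # Only items whose class image does not exist yet need to be generated.
    missing_data = [item for item in data_train if not item.class_image_path.exists()]

    if len(missing_data) == 0:
        return

    # Split the missing items into batches of sample_batch_size.
    batched_data = [
        missing_data[i:i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_name = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            ).images

            for i, image in enumerate(images):
                image.save(image_name[i])

    del pipeline

    # Release GPU memory held by the pipeline before training starts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Finally, with the `default_train_dataloader` alias gone, the remaining hunk updates its only use to index the dataloader list directly: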
```diff
@@ -898,7 +870,7 @@ def main():
             accelerator,
             text_encoder,
             optimizer,
-            default_train_dataloader,
+            train_dataloaders[0],
             val_dataloader,
             loop,
             on_train=on_train,
```
