author    Volpeon <git@volpeon.ink>  2023-01-07 17:10:06 +0100
committer Volpeon <git@volpeon.ink>  2023-01-07 17:10:06 +0100
commit    3353ffb64c280a938a0f2513d13b716c1fca8c02 (patch)
tree      dbdc2ae1ddc5dc7758a2210e14e1fc9b18df7697 /train_ti.py
parent    Made aspect ratio bucketing configurable (diff)
Cleanup
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  60
1 file changed, 16 insertions, 44 deletions
diff --git a/train_ti.py b/train_ti.py
index 38c9755..b4b602b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -22,7 +22,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import run_model
+from training.common import run_model, generate_class_images
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
@@ -219,7 +219,6 @@ def parse_args():
     parser.add_argument(
         "--scale_lr",
         action="store_true",
-        default=True,
         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
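Note on the hunk above: argparse's action="store_true" already implies a default of False, and passing the flag sets the value to True. With default=True added, the option was True whether or not --scale_lr was given, so the flag was a no-op; dropping the default makes learning-rate scaling opt-in again. A minimal standalone sketch (not part of the patch) demonstrating the difference:

import argparse

# Before the patch: default=True overrides store_true's implicit False,
# so the flag can never be switched off from the command line.
buggy = argparse.ArgumentParser()
buggy.add_argument("--scale_lr", action="store_true", default=True)
assert buggy.parse_args([]).scale_lr is True            # always on
assert buggy.parse_args(["--scale_lr"]).scale_lr is True

# After the patch: the implicit default is False, so the flag is opt-in.
fixed = argparse.ArgumentParser()
fixed.add_argument("--scale_lr", action="store_true")
assert fixed.parse_args([]).scale_lr is False
assert fixed.parse_args(["--scale_lr"]).scale_lr is True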
@@ -734,50 +733,23 @@ def main():
     )
     datamodule.setup()
 
-    if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
-
-        if len(missing_data) != 0:
-            batched_data = [
-                missing_data[i:i+args.sample_batch_size]
-                for i in range(0, len(missing_data), args.sample_batch_size)
-            ]
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=vae,
-                unet=unet,
-                tokenizer=tokenizer,
-                scheduler=checkpoint_scheduler,
-            ).to(accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            with torch.inference_mode():
-                for batch in batched_data:
-                    image_name = [item.class_image_path for item in batch]
-                    prompt = [item.cprompt for item in batch]
-                    nprompt = [item.nprompt for item in batch]
-
-                    images = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=args.sample_image_size,
-                        width=args.sample_image_size,
-                        num_inference_steps=args.sample_steps
-                    ).images
-
-                    for i, image in enumerate(images):
-                        image.save(image_name[i])
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     train_dataloaders = datamodule.train_dataloaders
-    default_train_dataloader = train_dataloaders[0]
     val_dataloader = datamodule.val_dataloader
 
+    if args.num_class_images != 0:
+        generate_class_images(
+            accelerator,
+            text_encoder,
+            vae,
+            unet,
+            tokenizer,
+            checkpoint_scheduler,
+            datamodule.data_train,
+            args.sample_batch_size,
+            args.sample_image_size,
+            args.sample_steps
+        )
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
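Note on the hunk above: the deleted inline block has moved behind the new generate_class_images import from training.common, but only the call site is visible in this diff. As a hedged reconstruction, assembled from the deleted lines and the new call's argument order rather than from the actual training/common.py, the helper would look roughly like this:

# Hypothetical reconstruction of training/common.py:generate_class_images;
# the real implementation is not shown in this diff.
import torch

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion


def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    # Only items whose class image is missing on disk need to be generated.
    missing_data = [item for item in data_train if not item.class_image_path.exists()]

    if len(missing_data) == 0:
        return

    # Chunk the missing items into inference batches.
    batched_data = [
        missing_data[i:i + sample_batch_size]
        for i in range(0, len(missing_data), sample_batch_size)
    ]

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for batch in batched_data:
            image_name = [item.class_image_path for item in batch]
            prompt = [item.cprompt for item in batch]
            nprompt = [item.nprompt for item in batch]

            images = pipeline(
                prompt=prompt,
                negative_prompt=nprompt,
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            ).images

            for i, image in enumerate(images):
                image.save(image_name[i])

    # Free the sampling pipeline before training starts.
    del pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()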
@@ -898,7 +870,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        default_train_dataloader,
+        train_dataloaders[0],
         val_dataloader,
         loop,
         on_train=on_train,
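Note on the last hunk: it drops the one-use default_train_dataloader alias in favor of indexing train_dataloaders directly. Judging from the argument list and the LRFinder import in the first hunk, this call site is plausibly the LRFinder setup; a sketch under that assumption (the hunk shows only the arguments, not the receiver):

lr_finder = LRFinder(          # assumed receiver; not shown in the hunk
    accelerator,
    text_encoder,
    optimizer,
    train_dataloaders[0],      # was: default_train_dataloader
    val_dataloader,
    loop,
    on_train=on_train,
    # ... any further keyword arguments fall outside the hunk
)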