summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-02 17:34:11 +0100
committerVolpeon <git@volpeon.ink>2023-01-02 17:34:11 +0100
commit67aaba2159bcda4c0b8538b1580a40f01e8f0964 (patch)
treee1308417bde00609a5347bc39a8cd6583fd066f8
parentFix (diff)
downloadtextual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.gz
textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.tar.bz2
textual-inversion-diff-67aaba2159bcda4c0b8538b1580a40f01e8f0964.zip
Update
-rw-r--r--train_dreambooth.py16
-rw-r--r--train_ti.py15
-rw-r--r--training/lr.py33
3 files changed, 55 insertions, 9 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 05f6cb5..1e49474 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -565,8 +565,6 @@ def main():
565 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( 565 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
566 args.pretrained_model_name_or_path, subfolder='scheduler') 566 args.pretrained_model_name_or_path, subfolder='scheduler')
567 567
568 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
569
570 vae.enable_slicing() 568 vae.enable_slicing()
571 vae.set_use_memory_efficient_attention_xformers(True) 569 vae.set_use_memory_efficient_attention_xformers(True)
572 unet.set_use_memory_efficient_attention_xformers(True) 570 unet.set_use_memory_efficient_attention_xformers(True)
@@ -893,7 +891,16 @@ def main():
893 accelerator.init_trackers("dreambooth", config=config) 891 accelerator.init_trackers("dreambooth", config=config)
894 892
895 if args.find_lr: 893 if args.find_lr:
896 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) 894 lr_finder = LRFinder(
895 accelerator,
896 text_encoder,
897 optimizer,
898 train_dataloader,
899 val_dataloader,
900 loop,
901 on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
902 on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
903 )
897 lr_finder.run(min_lr=1e-4) 904 lr_finder.run(min_lr=1e-4)
898 905
899 plt.savefig(basepath.joinpath("lr.png")) 906 plt.savefig(basepath.joinpath("lr.png"))
@@ -965,11 +972,11 @@ def main():
965 local_progress_bar.reset() 972 local_progress_bar.reset()
966 973
967 unet.train() 974 unet.train()
968
969 if epoch < args.train_text_encoder_epochs: 975 if epoch < args.train_text_encoder_epochs:
970 text_encoder.train() 976 text_encoder.train()
971 elif epoch == args.train_text_encoder_epochs: 977 elif epoch == args.train_text_encoder_epochs:
972 text_encoder.requires_grad_(False) 978 text_encoder.requires_grad_(False)
979 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
973 980
974 for step, batch in enumerate(train_dataloader): 981 for step, batch in enumerate(train_dataloader):
975 with accelerator.accumulate(unet): 982 with accelerator.accumulate(unet):
@@ -1023,6 +1030,7 @@ def main():
1023 1030
1024 unet.eval() 1031 unet.eval()
1025 text_encoder.eval() 1032 text_encoder.eval()
1033 tokenizer.set_use_vector_shuffle(False)
1026 1034
1027 cur_loss_val = AverageMeter() 1035 cur_loss_val = AverageMeter()
1028 cur_acc_val = AverageMeter() 1036 cur_acc_val = AverageMeter()
diff --git a/train_ti.py b/train_ti.py
index 97dde1e..2b3f017 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -535,8 +535,6 @@ def main():
535 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( 535 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
536 args.pretrained_model_name_or_path, subfolder='scheduler') 536 args.pretrained_model_name_or_path, subfolder='scheduler')
537 537
538 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
539
540 vae.enable_slicing() 538 vae.enable_slicing()
541 vae.set_use_memory_efficient_attention_xformers(True) 539 vae.set_use_memory_efficient_attention_xformers(True)
542 unet.set_use_memory_efficient_attention_xformers(True) 540 unet.set_use_memory_efficient_attention_xformers(True)
@@ -845,7 +843,16 @@ def main():
845 accelerator.init_trackers("textual_inversion", config=config) 843 accelerator.init_trackers("textual_inversion", config=config)
846 844
847 if args.find_lr: 845 if args.find_lr:
848 lr_finder = LRFinder(accelerator, text_encoder, optimizer, train_dataloader, val_dataloader, loop) 846 lr_finder = LRFinder(
847 accelerator,
848 text_encoder,
849 optimizer,
850 train_dataloader,
851 val_dataloader,
852 loop,
853 on_train=lambda: tokenizer.set_use_vector_shuffle(args.vector_shuffle),
854 on_eval=lambda: tokenizer.set_use_vector_shuffle(False)
855 )
849 lr_finder.run(min_lr=1e-4) 856 lr_finder.run(min_lr=1e-4)
850 857
851 plt.savefig(basepath.joinpath("lr.png")) 858 plt.savefig(basepath.joinpath("lr.png"))
@@ -915,6 +922,7 @@ def main():
915 local_progress_bar.reset() 922 local_progress_bar.reset()
916 923
917 text_encoder.train() 924 text_encoder.train()
925 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
918 926
919 for step, batch in enumerate(train_dataloader): 927 for step, batch in enumerate(train_dataloader):
920 with accelerator.accumulate(text_encoder): 928 with accelerator.accumulate(text_encoder):
@@ -955,6 +963,7 @@ def main():
955 accelerator.wait_for_everyone() 963 accelerator.wait_for_everyone()
956 964
957 text_encoder.eval() 965 text_encoder.eval()
966 tokenizer.set_use_vector_shuffle(False)
958 967
959 cur_loss_val = AverageMeter() 968 cur_loss_val = AverageMeter()
960 cur_acc_val = AverageMeter() 969 cur_acc_val = AverageMeter()
diff --git a/training/lr.py b/training/lr.py
index 3abd2f2..fe166ed 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,5 +1,6 @@
1import math 1import math
2import copy 2import copy
3from typing import Callable
3 4
4import matplotlib.pyplot as plt 5import matplotlib.pyplot as plt
5import numpy as np 6import numpy as np
@@ -10,19 +11,45 @@ from tqdm.auto import tqdm
10from training.util import AverageMeter 11from training.util import AverageMeter
11 12
12 13
14def noop():
15 pass
16
17
13class LRFinder(): 18class LRFinder():
14 def __init__(self, accelerator, model, optimizer, train_dataloader, val_dataloader, loss_fn): 19 def __init__(
20 self,
21 accelerator,
22 model,
23 optimizer,
24 train_dataloader,
25 val_dataloader,
26 loss_fn,
27 on_train: Callable[[], None] = noop,
28 on_eval: Callable[[], None] = noop
29 ):
15 self.accelerator = accelerator 30 self.accelerator = accelerator
16 self.model = model 31 self.model = model
17 self.optimizer = optimizer 32 self.optimizer = optimizer
18 self.train_dataloader = train_dataloader 33 self.train_dataloader = train_dataloader
19 self.val_dataloader = val_dataloader 34 self.val_dataloader = val_dataloader
20 self.loss_fn = loss_fn 35 self.loss_fn = loss_fn
36 self.on_train = on_train
37 self.on_eval = on_eval
21 38
22 # self.model_state = copy.deepcopy(model.state_dict()) 39 # self.model_state = copy.deepcopy(model.state_dict())
23 # self.optimizer_state = copy.deepcopy(optimizer.state_dict()) 40 # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
24 41
25 def run(self, min_lr, skip_start=10, skip_end=5, num_epochs=100, num_train_batches=1, num_val_batches=math.inf, smooth_f=0.05, diverge_th=5): 42 def run(
43 self,
44 min_lr,
45 skip_start: int = 10,
46 skip_end: int = 5,
47 num_epochs: int = 100,
48 num_train_batches: int = 1,
49 num_val_batches: int = math.inf,
50 smooth_f: float = 0.05,
51 diverge_th: int = 5
52 ):
26 best_loss = None 53 best_loss = None
27 best_acc = None 54 best_acc = None
28 55
@@ -50,6 +77,7 @@ class LRFinder():
50 avg_acc = AverageMeter() 77 avg_acc = AverageMeter()
51 78
52 self.model.train() 79 self.model.train()
80 self.on_train()
53 81
54 for step, batch in enumerate(self.train_dataloader): 82 for step, batch in enumerate(self.train_dataloader):
55 if step >= num_train_batches: 83 if step >= num_train_batches:
@@ -67,6 +95,7 @@ class LRFinder():
67 progress_bar.update(1) 95 progress_bar.update(1)
68 96
69 self.model.eval() 97 self.model.eval()
98 self.on_eval()
70 99
71 with torch.inference_mode(): 100 with torch.inference_mode():
72 for step, batch in enumerate(self.val_dataloader): 101 for step, batch in enumerate(self.val_dataloader):