 data/csv.py               | 2 +-
 infer.py                  | 7 +++----
 models/clip/embeddings.py | 4 +++-
 models/clip/tokenizer.py  | 2 +-
 train_ti.py               | 2 +-
 training/lr.py            | 7 ++++++-
 6 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index c505230..a60733a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -151,7 +151,7 @@ class CSVDataModule():
 
         num_images = len(items)
 
-        valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.2)
+        valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
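
The default validation split drops from 20% of the dataset to 10%, now computed with floor division, and the max(valid_set_size, 1) guard still guarantees at least one validation image on tiny datasets. A standalone sketch of the resulting split (the helper name and example sizes are illustrative, not from data/csv.py):

    def split_sizes(num_images: int, valid_set_size: int | None = None) -> tuple[int, int]:
        # New default: 10% of the images, rounded down, instead of 20%
        valid = valid_set_size if valid_set_size is not None else num_images // 10
        valid = max(valid, 1)  # never let the validation set be empty
        return num_images - valid, valid

    assert split_sizes(100) == (90, 10)
    assert split_sizes(5) == (4, 1)  # 5 // 10 == 0, so the guard kicks in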
diff --git a/infer.py b/infer.py
index c4d1e0d..b29b136 100644
--- a/infer.py
+++ b/infer.py
@@ -295,11 +295,10 @@ def generate(output_dir, pipeline, args):
         for j, image in enumerate(images):
             image_dir = output_dir
             if use_subdirs:
-                idx = j % len(args.prompt)
-                image_dir = image_dir.joinpath(slugify(args.prompt[idx])[:100])
+                image_dir = image_dir.joinpath(slugify(args.prompt[j % len(args.prompt)])[:100])
             image_dir.mkdir(parents=True, exist_ok=True)
-            image.save(image_dir.joinpath(f"{seed}_{j}.png"))
-            image.save(image_dir.joinpath(f"{seed}_{j}.jpg"), quality=85)
+            image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
+            image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
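
The pipeline returns images interleaved across prompts, so j % len(args.prompt) picks the prompt an image belongs to, and the new j // len(args.prompt) is a running counter within that prompt: files in each per-prompt subdirectory are now numbered 0, 1, 2, ... instead of jumping by the prompt count. A toy illustration (the prompt strings are made up):

    prompts = ["a cat", "a dog"]
    for j in range(6):                   # six images, interleaved across prompts
        prompt_idx = j % len(prompts)    # which prompt produced image j
        sample_idx = j // len(prompts)   # running counter within that prompt
        print(f"{prompts[prompt_idx]}/seed_{sample_idx}.png")
    # a cat/seed_0.png, a dog/seed_0.png, a cat/seed_1.png, a dog/seed_1.png, ...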
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 9c3a56b..1280ebd 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
         self.temp_token_embedding.weight.data[token_ids] = initializer.to(
-            dtype=self.temp_token_embedding.weight.dtype)
+            device=self.temp_token_embedding.weight.device,
+            dtype=self.temp_token_embedding.weight.dtype,
+        )
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
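
A dtype-only Tensor.to() leaves the tensor on whatever device it was created on, so a CPU-built initializer assigned into a CUDA embedding table raises a device-mismatch error; passing device= and dtype= together moves and casts in one call. A minimal sketch of the pattern (names and shapes are illustrative):

    import torch

    weight = torch.zeros(4, 8)                            # stands in for the embedding table
    initializer = torch.randn(2, 8, dtype=torch.float64)  # e.g. built on the CPU

    # One hop: match both the device and the dtype of the target weights
    moved = initializer.to(device=weight.device, dtype=weight.dtype)
    weight.data[torch.tensor([1, 3])] = moved             # same device, same dtype: safe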
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 11a3df0..4e97ab5 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -48,7 +48,7 @@ def shuffle_none(tokens: list[int]):
 
 
 def shuffle_auto(tokens: list[int]):
-    if len(tokens) >= 4:
+    if len(tokens) >= 5:
         return shuffle_between(tokens)
     if len(tokens) >= 3:
         return shuffle_trailing(tokens)
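
Raising the threshold from 4 to 5 sends four-token prompts to shuffle_trailing instead of shuffle_between: with both ends pinned, a four-token prompt leaves only two middle tokens to permute. A sketch of the resulting dispatch, assuming shuffle_between pins the first and last tokens and shuffle_trailing pins only the first (inferred from the names; those bodies are not part of this diff):

    import random

    def shuffle_between(tokens: list[int]) -> list[int]:
        mid = tokens[1:-1]           # assumed: keep first and last in place
        random.shuffle(mid)
        return [tokens[0], *mid, tokens[-1]]

    def shuffle_trailing(tokens: list[int]) -> list[int]:
        tail = tokens[1:]            # assumed: keep only the first in place
        random.shuffle(tail)
        return [tokens[0], *tail]

    def shuffle_auto(tokens: list[int]) -> list[int]:
        if len(tokens) >= 5:         # was 4: now requires >= 3 shuffleable middle tokens
            return shuffle_between(tokens)
        if len(tokens) >= 3:         # a 4-token prompt falls through to here
            return shuffle_trailing(tokens)
        return tokens                # too short to shuffle meaningfully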
diff --git a/train_ti.py b/train_ti.py
index 1b60f64..8ada98c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -889,7 +889,7 @@ def main():
         on_train=on_train,
         on_eval=on_eval,
     )
-    lr_finder.run(end_lr=1e3)
+    lr_finder.run(num_epochs=200, end_lr=1e3)
 
     plt.savefig(basepath.joinpath("lr.png"), dpi=300)
     plt.close()
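
Pinning num_epochs=200 fixes the length of the learning-rate range test instead of relying on the finder's default. Such a test conventionally sweeps the rate exponentially from its starting value up to end_lr; a sketch of that schedule (the formula is the standard exponential sweep, assumed rather than read from training/lr.py):

    start_lr, end_lr, num_epochs = 1e-6, 1e3, 200
    for step in range(num_epochs):
        # Geometric interpolation: step 0 gives start_lr, the last step gives end_lr
        lr = start_lr * (end_lr / start_lr) ** (step / (num_epochs - 1))
        # ...train one mini-epoch at `lr` and record the smoothed loss...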
diff --git a/training/lr.py b/training/lr.py
index c8dc040..3cdf994 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -12,7 +12,7 @@ from tqdm.auto import tqdm
 from training.util import AverageMeter
 
 
-def noop():
+def noop(*args, **kwargs):
     pass
 
 
@@ -26,6 +26,7 @@ class LRFinder():
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
+        on_clip: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
         self.accelerator = accelerator
@@ -35,6 +36,7 @@ class LRFinder():
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
         self.on_train = on_train
+        self.on_clip = on_clip
         self.on_eval = on_eval
 
         # self.model_state = copy.deepcopy(model.state_dict())
@@ -93,6 +95,9 @@ class LRFinder():
 
                 self.accelerator.backward(loss)
 
+                if self.accelerator.sync_gradients:
+                    self.on_clip()
+
                 self.optimizer.step()
                 lr_scheduler.step()
                 self.optimizer.zero_grad(set_to_none=True)
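
With gradient accumulation, accelerator.sync_gradients is True only on the step where Accelerate actually applies the accumulated gradients, which is the one point where clipping makes sense, so the new on_clip hook fires exactly there. A sketch of how a caller might wire it up; the clipping lambda and the leading positional arguments are assumptions inferred from the constructor fields, since the diff adds only the hook (Accelerator.clip_grad_norm_ is the real accelerate API):

    lr_finder = LRFinder(
        accelerator, model, optimizer,
        train_dataloader, val_dataloader, loss_fn,
        on_train=on_train,
        on_clip=lambda: accelerator.clip_grad_norm_(model.parameters(), 1.0),
        on_eval=on_eval,
    )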