diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-04 22:06:05 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-04 22:06:05 +0100 |
| commit | a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68 (patch) | |
| tree | 8bd97a745e1113b1035c504ec484e099f878aed0 | |
| parent | Various updates (diff) | |
| download | textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.tar.gz textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.tar.bz2 textual-inversion-diff-a5e45e2c0dab95589e5fbaa4fe87d18484fbbe68.zip | |
Update
| -rw-r--r-- | data/csv.py | 2 | ||||
| -rw-r--r-- | infer.py | 7 | ||||
| -rw-r--r-- | models/clip/embeddings.py | 4 | ||||
| -rw-r--r-- | models/clip/tokenizer.py | 2 | ||||
| -rw-r--r-- | train_ti.py | 2 | ||||
| -rw-r--r-- | training/lr.py | 7 |
6 files changed, 15 insertions, 9 deletions
diff --git a/data/csv.py b/data/csv.py index c505230..a60733a 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -151,7 +151,7 @@ class CSVDataModule(): | |||
| 151 | 151 | ||
| 152 | num_images = len(items) | 152 | num_images = len(items) |
| 153 | 153 | ||
| 154 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else int(num_images * 0.2) | 154 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10 |
| 155 | valid_set_size = max(valid_set_size, 1) | 155 | valid_set_size = max(valid_set_size, 1) |
| 156 | train_set_size = num_images - valid_set_size | 156 | train_set_size = num_images - valid_set_size |
| 157 | 157 | ||
| @@ -295,11 +295,10 @@ def generate(output_dir, pipeline, args): | |||
| 295 | for j, image in enumerate(images): | 295 | for j, image in enumerate(images): |
| 296 | image_dir = output_dir | 296 | image_dir = output_dir |
| 297 | if use_subdirs: | 297 | if use_subdirs: |
| 298 | idx = j % len(args.prompt) | 298 | image_dir = image_dir.joinpath(slugify(args.prompt[j % len(args.prompt)])[:100]) |
| 299 | image_dir = image_dir.joinpath(slugify(args.prompt[idx])[:100]) | ||
| 300 | image_dir.mkdir(parents=True, exist_ok=True) | 299 | image_dir.mkdir(parents=True, exist_ok=True) |
| 301 | image.save(image_dir.joinpath(f"{seed}_{j}.png")) | 300 | image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) |
| 302 | image.save(image_dir.joinpath(f"{seed}_{j}.jpg"), quality=85) | 301 | image.save(image_dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) |
| 303 | 302 | ||
| 304 | if torch.cuda.is_available(): | 303 | if torch.cuda.is_available(): |
| 305 | torch.cuda.empty_cache() | 304 | torch.cuda.empty_cache() |
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 9c3a56b..1280ebd 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -72,7 +72,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 72 | 72 | ||
| 73 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) | 73 | self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids]) |
| 74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( | 74 | self.temp_token_embedding.weight.data[token_ids] = initializer.to( |
| 75 | dtype=self.temp_token_embedding.weight.dtype) | 75 | device=self.temp_token_embedding.weight.device, |
| 76 | dtype=self.temp_token_embedding.weight.dtype, | ||
| 77 | ) | ||
| 76 | 78 | ||
| 77 | def load_embed(self, input_ids: list[int], filename: Path): | 79 | def load_embed(self, input_ids: list[int], filename: Path): |
| 78 | with safe_open(filename, framework="pt", device="cpu") as file: | 80 | with safe_open(filename, framework="pt", device="cpu") as file: |
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 11a3df0..4e97ab5 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
| @@ -48,7 +48,7 @@ def shuffle_none(tokens: list[int]): | |||
| 48 | 48 | ||
| 49 | 49 | ||
| 50 | def shuffle_auto(tokens: list[int]): | 50 | def shuffle_auto(tokens: list[int]): |
| 51 | if len(tokens) >= 4: | 51 | if len(tokens) >= 5: |
| 52 | return shuffle_between(tokens) | 52 | return shuffle_between(tokens) |
| 53 | if len(tokens) >= 3: | 53 | if len(tokens) >= 3: |
| 54 | return shuffle_trailing(tokens) | 54 | return shuffle_trailing(tokens) |
diff --git a/train_ti.py b/train_ti.py index 1b60f64..8ada98c 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -889,7 +889,7 @@ def main(): | |||
| 889 | on_train=on_train, | 889 | on_train=on_train, |
| 890 | on_eval=on_eval, | 890 | on_eval=on_eval, |
| 891 | ) | 891 | ) |
| 892 | lr_finder.run(end_lr=1e3) | 892 | lr_finder.run(num_epochs=200, end_lr=1e3) |
| 893 | 893 | ||
| 894 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 894 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) |
| 895 | plt.close() | 895 | plt.close() |
diff --git a/training/lr.py b/training/lr.py index c8dc040..3cdf994 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -12,7 +12,7 @@ from tqdm.auto import tqdm | |||
| 12 | from training.util import AverageMeter | 12 | from training.util import AverageMeter |
| 13 | 13 | ||
| 14 | 14 | ||
| 15 | def noop(): | 15 | def noop(*args, **kwards): |
| 16 | pass | 16 | pass |
| 17 | 17 | ||
| 18 | 18 | ||
| @@ -26,6 +26,7 @@ class LRFinder(): | |||
| 26 | val_dataloader, | 26 | val_dataloader, |
| 27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 27 | loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
| 28 | on_train: Callable[[], None] = noop, | 28 | on_train: Callable[[], None] = noop, |
| 29 | on_clip: Callable[[], None] = noop, | ||
| 29 | on_eval: Callable[[], None] = noop | 30 | on_eval: Callable[[], None] = noop |
| 30 | ): | 31 | ): |
| 31 | self.accelerator = accelerator | 32 | self.accelerator = accelerator |
| @@ -35,6 +36,7 @@ class LRFinder(): | |||
| 35 | self.val_dataloader = val_dataloader | 36 | self.val_dataloader = val_dataloader |
| 36 | self.loss_fn = loss_fn | 37 | self.loss_fn = loss_fn |
| 37 | self.on_train = on_train | 38 | self.on_train = on_train |
| 39 | self.on_clip = on_clip | ||
| 38 | self.on_eval = on_eval | 40 | self.on_eval = on_eval |
| 39 | 41 | ||
| 40 | # self.model_state = copy.deepcopy(model.state_dict()) | 42 | # self.model_state = copy.deepcopy(model.state_dict()) |
| @@ -93,6 +95,9 @@ class LRFinder(): | |||
| 93 | 95 | ||
| 94 | self.accelerator.backward(loss) | 96 | self.accelerator.backward(loss) |
| 95 | 97 | ||
| 98 | if self.accelerator.sync_gradients: | ||
| 99 | self.on_clip() | ||
| 100 | |||
| 96 | self.optimizer.step() | 101 | self.optimizer.step() |
| 97 | lr_scheduler.step() | 102 | lr_scheduler.step() |
| 98 | self.optimizer.zero_grad(set_to_none=True) | 103 | self.optimizer.zero_grad(set_to_none=True) |
