| field | value |
|---|---|
| author | Volpeon <git@volpeon.ink> (2023-01-01 22:08:21 +0100) |
| committer | Volpeon <git@volpeon.ink> (2023-01-01 22:08:21 +0100) |
| commit | 68164329b97f5cd79a56372dc6cace4b038afce8 (patch) |
| tree | 50d404f764a8e6f85fadcc0a45dd0b8da3b6e507 |
| parent | Cleanup (diff) |
| download | textual-inversion-diff-68164329b97f5cd79a56372dc6cace4b038afce8.tar.gz, textual-inversion-diff-68164329b97f5cd79a56372dc6cace4b038afce8.tar.bz2, textual-inversion-diff-68164329b97f5cd79a56372dc6cace4b038afce8.zip |
Update
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | models/clip/embeddings.py | 22 |
| -rw-r--r-- | train_ti.py | 2 |
| -rw-r--r-- | training/optimization.py | 6 |
3 files changed, 15 insertions, 15 deletions
```diff
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index f90e7c2..9c3a56b 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -56,23 +56,23 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         if isinstance(token_ids, int):
             token_ids = [token_ids]
 
-        if initializer is not None:
-            if isinstance(initializer, int):
-                initializer = [initializer]
+        if initializer is None:
+            initializer = token_ids
 
-            if isinstance(initializer, list):
-                initializer = (initializer * len(token_ids))[:len(token_ids)]
+        if isinstance(initializer, int):
+            initializer = [initializer]
 
-                with torch.no_grad():
-                    initializer = self.get_embed(initializer)
+        if isinstance(initializer, list):
+            initializer = (initializer * len(token_ids))[:len(token_ids)]
+
+            with torch.no_grad():
+                initializer = self.get_embed(initializer)
 
         token_ids = torch.tensor(token_ids, dtype=torch.long)
 
         self.temp_token_ids = torch.cat([self.temp_token_ids, token_ids])
-
-        if initializer is not None:
-            self.temp_token_embedding.weight.data[token_ids] = initializer.to(
-                dtype=self.temp_token_embedding.weight.dtype)
+        self.temp_token_embedding.weight.data[token_ids] = initializer.to(
+            dtype=self.temp_token_embedding.weight.dtype)
 
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
```
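After this hunk, `initializer` can no longer be `None` by the time the embedding rows are written: a missing initializer now falls back to the new tokens' own ids, so every new token starts from its current embedding. Below is a minimal standalone sketch of the resulting resolution logic; the `nn.Embedding`, its sizes, and the free-standing `get_embed` helper are illustrative stand-ins for the class's own members, not the repository's actual code.

```python
from typing import Optional, Union

import torch
import torch.nn as nn

# Illustrative stand-ins for the class's temp_token_embedding / get_embed.
embedding = nn.Embedding(1000, 768)

def get_embed(ids: list[int]) -> torch.Tensor:
    return embedding(torch.tensor(ids, dtype=torch.long))

def resolve_initializer(
    token_ids: list[int],
    initializer: Optional[Union[int, list[int], torch.Tensor]] = None,
) -> torch.Tensor:
    if initializer is None:
        # New behaviour: default to the new tokens' own ids, so each
        # token is initialized from its current embedding.
        initializer = token_ids

    if isinstance(initializer, int):
        initializer = [initializer]

    if isinstance(initializer, list):
        # Repeat/truncate so there is exactly one source id per new token.
        initializer = (initializer * len(token_ids))[:len(token_ids)]

        with torch.no_grad():
            initializer = get_embed(initializer)

    # By here initializer is always a (len(token_ids), dim) tensor,
    # so the caller can assign it into the weight matrix unconditionally.
    return initializer

print(resolve_initializer([3, 4]).shape)     # torch.Size([2, 768])
print(resolve_initializer([3, 4], 7).shape)  # torch.Size([2, 768])
```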
```diff
diff --git a/train_ti.py b/train_ti.py
index 775b918..870bd40 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -250,7 +250,7 @@ def parse_args():
     parser.add_argument(
         "--lr_annealing_exp",
         type=int,
-        default=1,
+        default=2,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
     parser.add_argument(
```
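The default bump from 1 to 2 matters because `lr_annealing_exp` exponentiates the cosine annealing curve (see the training/optimization.py hunk below), and a higher exponent decays the learning rate faster early in the annealing phase. A quick numeric illustration, evaluating the new interpolated-exponent formula in isolation (a hypothetical helper, not part of the repository):

```python
import math

def cos_anneal(progress: float, annealing_exp: int) -> float:
    # Plain cosine decay from 1 to 0 over the annealing phase...
    lr = 0.5 * (1.0 + math.cos(math.pi * progress))
    # ...raised to an exponent that slides from annealing_exp down to 1.
    return lr ** (annealing_exp - (annealing_exp - 1) * progress)

for p in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"progress={p:.2f}  exp=1: {cos_anneal(p, 1):.3f}  exp=2: {cos_anneal(p, 2):.3f}")
```

With exponent 1 the formula reduces to the plain cosine; with the new default of 2 the curve sits below it, most noticeably in the first half of annealing.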
```diff
diff --git a/training/optimization.py b/training/optimization.py
index a79944f..725599b 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -30,7 +30,7 @@ def get_one_cycle_schedule(
                 return min_lr + progress * (1 - min_lr)
 
             lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
-            lr = lr ** warmup_exp
+            lr = lr ** (warmup_exp - (warmup_exp - 1) * progress)
             return min_lr + lr * (1 - min_lr)
 
         if annealing == "linear":
@@ -47,11 +47,11 @@ def get_one_cycle_schedule(
 
         if annealing == "half_cos":
             lr = 1.0 + math.cos(math.pi * (0.5 + 0.5 * progress))
-            lr = lr ** annealing_exp
+            lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
             return lr
 
         lr = 0.5 * (1.0 + math.cos(math.pi * progress))
-        lr = lr ** annealing_exp
+        lr = lr ** (annealing_exp - (annealing_exp - 1) * progress)
         return lr
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
```
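The new exponent `annealing_exp - (annealing_exp - 1) * progress` slides linearly from `annealing_exp` at the start of the phase down to 1 at the end, so the schedule decays faster than a plain cosine early on but converges to the plain cosine tail instead of staying uniformly flattened. A self-contained sketch of a one-cycle schedule built around this formula, fixed to the `cos` warmup/annealing branches for brevity; parameter names mirror `get_one_cycle_schedule`, but the defaults and the demo loop are assumptions:

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR

def one_cycle_cos(optimizer, num_training_steps, warmup_exp=1, annealing_exp=2,
                  min_lr=0.04, mid_point=0.3, last_epoch=-1):
    thresh_up = int(num_training_steps * min(mid_point, 0.5))

    def lr_lambda(current_step):
        if current_step < thresh_up:
            # Cosine warmup from min_lr up to the full learning rate.
            progress = current_step / max(1, thresh_up)
            lr = 0.5 * (1.0 + math.cos(math.pi * (1 + progress)))
            lr = lr ** (warmup_exp - (warmup_exp - 1) * progress)
            return min_lr + lr * (1 - min_lr)

        # Cosine annealing with the interpolated exponent from this commit.
        progress = (current_step - thresh_up) / max(1, num_training_steps - thresh_up)
        lr = 0.5 * (1.0 + math.cos(math.pi * progress))
        return lr ** (annealing_exp - (annealing_exp - 1) * progress)

    return LambdaLR(optimizer, lr_lambda, last_epoch)

# Hypothetical usage: a throwaway parameter/optimizer pair to drive the schedule.
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = one_cycle_cos(opt, num_training_steps=100)
for _ in range(100):
    opt.step()
    sched.step()
```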
