diff options
author | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
commit | e0b686b475885f0c8480f7173eaa7359adf17e27 (patch) | |
tree | 6ad882f152e63801d31230466e4d6468e7ada697 /train_ti.py | |
parent | Code simplifications, avoid autocast (diff) | |
download | textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2 textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip |
Set default dimensions to 768; add config inheritance
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 11 |
1 file changed, 5 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py index a228795..6e30ac3 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
21 | from slugify import slugify | 21 | from slugify import slugify |
22 | 22 | ||
23 | from common import load_text_embeddings, load_text_embedding | 23 | from common import load_text_embeddings, load_text_embedding, load_config |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import CSVDataModule, CSVDataItem |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
@@ -225,7 +225,7 @@ def parse_args(): | |||
225 | parser.add_argument( | 225 | parser.add_argument( |
226 | "--adam_weight_decay", | 226 | "--adam_weight_decay", |
227 | type=float, | 227 | type=float, |
228 | default=1e-2, | 228 | default=0, |
229 | help="Weight decay to use." | 229 | help="Weight decay to use." |
230 | ) | 230 | ) |
231 | parser.add_argument( | 231 | parser.add_argument( |
@@ -324,9 +324,8 @@ def parse_args(): | |||
324 | 324 | ||
325 | args = parser.parse_args() | 325 | args = parser.parse_args() |
326 | if args.config is not None: | 326 | if args.config is not None: |
327 | with open(args.config, 'rt') as f: | 327 | args = load_config(args.config) |
328 | args = parser.parse_args( | 328 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
329 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
330 | 329 | ||
331 | if args.train_data_file is None: | 330 | if args.train_data_file is None: |
332 | raise ValueError("You must specify --train_data_file") | 331 | raise ValueError("You must specify --train_data_file") |
@@ -407,7 +406,7 @@ class Checkpointer(CheckpointerBase): | |||
407 | 406 | ||
408 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): | 407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): |
409 | # Save a checkpoint | 408 | # Save a checkpoint |
410 | learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id] | 409 | learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id] |
411 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | 410 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} |
412 | 411 | ||
413 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 412 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |