From e0b686b475885f0c8480f7173eaa7359adf17e27 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 26 Dec 2022 14:24:21 +0100 Subject: Set default dimensions to 768; add config inheritance --- train_ti.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'train_ti.py') diff --git a/train_ti.py b/train_ti.py index a228795..6e30ac3 100644 --- a/train_ti.py +++ b/train_ti.py @@ -20,7 +20,7 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from common import load_text_embeddings, load_text_embedding +from common import load_text_embeddings, load_text_embedding, load_config from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule, CSVDataItem from training.optimization import get_one_cycle_schedule @@ -225,7 +225,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=1e-2, + default=0, help="Weight decay to use." ) parser.add_argument( @@ -324,9 +324,8 @@ def parse_args(): args = parser.parse_args() if args.config is not None: - with open(args.config, 'rt') as f: - args = parser.parse_args( - namespace=argparse.Namespace(**json.load(f)["args"])) + args = load_config(args.config) + args = parser.parse_args(namespace=argparse.Namespace(**args)) if args.train_data_file is None: raise ValueError("You must specify --train_data_file") @@ -407,7 +406,7 @@ class Checkpointer(CheckpointerBase): for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): # Save a checkpoint - learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id] + learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id] learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) -- cgit v1.2.3-54-g00ecf