summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-26 14:24:21 +0100
committerVolpeon <git@volpeon.ink>2022-12-26 14:24:21 +0100
commite0b686b475885f0c8480f7173eaa7359adf17e27 (patch)
tree6ad882f152e63801d31230466e4d6468e7ada697 /train_ti.py
parentCode simplifications, avoid autocast (diff)
downloadtextual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz
textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2
textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip
Set default dimensions to 768; add config inheritance
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/train_ti.py b/train_ti.py
index a228795..6e30ac3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel, CLIPTokenizer
21from slugify import slugify 21from slugify import slugify
22 22
23from common import load_text_embeddings, load_text_embedding 23from common import load_text_embeddings, load_text_embedding, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule, CSVDataItem 25from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
@@ -225,7 +225,7 @@ def parse_args():
225 parser.add_argument( 225 parser.add_argument(
226 "--adam_weight_decay", 226 "--adam_weight_decay",
227 type=float, 227 type=float,
228 default=1e-2, 228 default=0,
229 help="Weight decay to use." 229 help="Weight decay to use."
230 ) 230 )
231 parser.add_argument( 231 parser.add_argument(
@@ -324,9 +324,8 @@ def parse_args():
324 324
325 args = parser.parse_args() 325 args = parser.parse_args()
326 if args.config is not None: 326 if args.config is not None:
327 with open(args.config, 'rt') as f: 327 args = load_config(args.config)
328 args = parser.parse_args( 328 args = parser.parse_args(namespace=argparse.Namespace(**args))
329 namespace=argparse.Namespace(**json.load(f)["args"]))
330 329
331 if args.train_data_file is None: 330 if args.train_data_file is None:
332 raise ValueError("You must specify --train_data_file") 331 raise ValueError("You must specify --train_data_file")
@@ -407,7 +406,7 @@ class Checkpointer(CheckpointerBase):
407 406
408 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): 407 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
409 # Save a checkpoint 408 # Save a checkpoint
410 learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id] 409 learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id]
411 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} 410 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
412 411
413 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 412 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)