From 37f5057e8c46693fb4ad02ad3b66b6c1ae46e79f Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Fri, 7 Apr 2023 11:31:21 +0200
Subject: Run PTI only if placeholder tokens arg isn't empty

---
 train_lora.py | 109 +++++++++++++++++++++++++++++-----------------------------
 train_ti.py   |  12 +++----
 2 files changed, 61 insertions(+), 60 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 6de3a75..daf1f6c 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -867,62 +867,63 @@ def main():
     # PTI
     # --------------------------------------------------------------------------------
 
-    pti_output_dir = output_dir / "pti"
-    pti_checkpoint_output_dir = pti_output_dir / "model"
-    pti_sample_output_dir = pti_output_dir / "samples"
-
-    pti_datamodule = create_datamodule(
-        batch_size=args.pti_batch_size,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-    )
-    pti_datamodule.setup()
-
-    num_pti_epochs = args.num_pti_epochs
-    pti_sample_frequency = args.sample_frequency
-    if num_pti_epochs is None:
-        num_pti_epochs = math.ceil(
-            args.num_pti_steps / len(pti_datamodule.train_dataset)
-        ) * args.pti_gradient_accumulation_steps
-        pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
-
-    pti_optimizer = create_optimizer(
-        [
-            {
-                "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                "lr": args.learning_rate_pti,
-                "weight_decay": 0,
-            },
-        ]
-    )
+    if len(args.placeholder_tokens) != 0:
+        pti_output_dir = output_dir / "pti"
+        pti_checkpoint_output_dir = pti_output_dir / "model"
+        pti_sample_output_dir = pti_output_dir / "samples"
+
+        pti_datamodule = create_datamodule(
+            batch_size=args.pti_batch_size,
+            filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
+
+        num_pti_epochs = args.num_pti_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_pti_epochs is None:
+            num_pti_epochs = math.ceil(
+                args.num_pti_steps / len(pti_datamodule.train_dataset)
+            ) * args.pti_gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps))
+
+        pti_optimizer = create_optimizer(
+            [
+                {
+                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                    "lr": args.learning_rate_pti,
+                    "weight_decay": 0,
+                },
+            ]
+        )
 
-    pti_lr_scheduler = create_lr_scheduler(
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        optimizer=pti_optimizer,
-        num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-        train_epochs=num_pti_epochs,
-    )
+        pti_lr_scheduler = create_lr_scheduler(
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_pti_epochs,
+        )
 
-    metrics = trainer(
-        strategy=textual_inversion_strategy,
-        project="pti",
-        train_dataloader=pti_datamodule.train_dataloader,
-        val_dataloader=pti_datamodule.val_dataloader,
-        optimizer=pti_optimizer,
-        lr_scheduler=pti_lr_scheduler,
-        num_train_epochs=num_pti_epochs,
-        gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-        # --
-        sample_output_dir=pti_sample_output_dir,
-        checkpoint_output_dir=pti_checkpoint_output_dir,
-        sample_frequency=pti_sample_frequency,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay=args.emb_decay,
-    )
+        metrics = trainer(
+            strategy=textual_inversion_strategy,
+            project="pti",
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_pti_epochs,
+            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+            # --
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+            placeholder_tokens=args.placeholder_tokens,
+            placeholder_token_ids=placeholder_token_ids,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay=args.emb_decay,
+        )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+        plot_metrics(metrics, pti_output_dir / "lr.png")
 
     # LORA
     # --------------------------------------------------------------------------------
@@ -994,7 +995,7 @@ def main():
         max_grad_norm=args.max_grad_norm,
     )
 
-    plot_metrics(metrics, output_dir/"lr.png")
+    plot_metrics(metrics, lora_output_dir / "lr.png")
 
 
 if __name__ == "__main__":
diff --git a/train_ti.py b/train_ti.py
index 344b412..c1c0eed 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -602,7 +602,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
@@ -743,7 +743,7 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir/"checkpoints"
+    checkpoint_output_dir = output_dir / "checkpoints"
 
     trainer = partial(
         train,
@@ -782,11 +782,11 @@ def main():
 
     def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
         if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}"
-            metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png"
+            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
+            metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
         else:
-            sample_output_dir = output_dir/"samples"
-            metrics_output_file = output_dir/f"lr.png"
+            sample_output_dir = output_dir / "samples"
+            metrics_output_file = output_dir / "lr.png"
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
-- 
cgit v1.2.3-70-g09d2