From 85953e46c0d36658293b1cd39e26f5f550b173f8 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Fri, 7 Apr 2023 21:41:53 +0200
Subject: Fix

---
 train_lora.py | 77 ++++++++++++++++++++++++++++++++---------------------------
 1 file changed, 42 insertions(+), 35 deletions(-)

diff --git a/train_lora.py b/train_lora.py
index 5b0a292..9f17495 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -444,6 +444,12 @@ def parse_args():
         default=1,
         help="How often to save a checkpoint and sample image",
     )
+    parser.add_argument(
+        "--pti_sample_frequency",
+        type=int,
+        default=1,
+        help="How often to save a checkpoint and sample image",
+    )
     parser.add_argument(
         "--sample_image_size",
         type=int,
@@ -887,47 +893,48 @@ def main():
         pti_datamodule.setup()
 
         num_pti_epochs = args.num_pti_epochs
-        pti_sample_frequency = args.sample_frequency
+        pti_sample_frequency = args.pti_sample_frequency
         if num_pti_epochs is None:
             num_pti_epochs = math.ceil(
                 args.num_pti_steps / len(pti_datamodule.train_dataset)
             ) * args.pti_gradient_accumulation_steps
             pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
 
-        pti_optimizer = create_optimizer(
-            [
-                {
-                    "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
-                    "lr": args.learning_rate_pti,
-                    "weight_decay": 0,
-                },
-            ]
-        )
-
-        pti_lr_scheduler = create_lr_scheduler(
-            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-            optimizer=pti_optimizer,
-            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-            train_epochs=num_pti_epochs,
-        )
-
-        metrics = trainer(
-            strategy=lora_strategy,
-            pti_mode=True,
-            project="pti",
-            train_dataloader=pti_datamodule.train_dataloader,
-            val_dataloader=pti_datamodule.val_dataloader,
-            optimizer=pti_optimizer,
-            lr_scheduler=pti_lr_scheduler,
-            num_train_epochs=num_pti_epochs,
-            gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-            # --
-            sample_output_dir=pti_sample_output_dir,
-            checkpoint_output_dir=pti_checkpoint_output_dir,
-            sample_frequency=math.inf,
-        )
-
-        plot_metrics(metrics, pti_output_dir / "lr.png")
+        if num_pti_epochs > 0:
+            pti_optimizer = create_optimizer(
+                [
+                    {
+                        "params": text_encoder.text_model.embeddings.token_override_embedding.params.parameters(),
+                        "lr": args.learning_rate_pti,
+                        "weight_decay": 0,
+                    },
+                ]
+            )
+
+            pti_lr_scheduler = create_lr_scheduler(
+                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+                optimizer=pti_optimizer,
+                num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+                train_epochs=num_pti_epochs,
+            )
+
+            metrics = trainer(
+                strategy=lora_strategy,
+                pti_mode=True,
+                project="pti",
+                train_dataloader=pti_datamodule.train_dataloader,
+                val_dataloader=pti_datamodule.val_dataloader,
+                optimizer=pti_optimizer,
+                lr_scheduler=pti_lr_scheduler,
+                num_train_epochs=num_pti_epochs,
+                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
+                # --
+                sample_output_dir=pti_sample_output_dir,
+                checkpoint_output_dir=pti_checkpoint_output_dir,
+                sample_frequency=pti_sample_frequency,
+            )
+
+            plot_metrics(metrics, pti_output_dir / "lr.png")
 
     # LORA
     # --------------------------------------------------------------------------------
-- 
cgit v1.2.3-70-g09d2