From 050dcbde0483c277cd632e05d8a0f73c87332785 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Thu, 20 Apr 2023 12:29:58 +0200
Subject: Update

---
 .../stable_diffusion/vlpn_stable_diffusion.py      |   2 +-
 train_lora.py                                      | 145 ++++++++++++++++-----
 train_ti.py                                        |  13 +-
 3 files changed, 126 insertions(+), 34 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a0dff54..aa3dbc6 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -301,7 +301,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
         t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
 
         timesteps = timesteps.to(device)
 
diff --git a/train_lora.py b/train_lora.py
index d0313fe..0ae8b31 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -302,6 +302,11 @@ def parse_args():
         default="none",
         help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
     )
+    parser.add_argument(
+        "--train_text_encoder_cycles",
+        default=999999,
+        help="Number of epochs the text encoder will be trained."
+    )
     parser.add_argument(
         "--find_lr",
         action="store_true",
@@ -919,6 +924,78 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
+    # PTI
+    # --------------------------------------------------------------------------------
+
+    if len(args.placeholder_tokens) != 0:
+        filter_tokens = [token for token in args.filter_tokens if token in args.placeholder_tokens]
+
+        pti_datamodule = create_datamodule(
+            batch_size=args.train_batch_size,
+            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
+        )
+        pti_datamodule.setup()
+
+        num_train_epochs = args.num_train_epochs
+        pti_sample_frequency = args.sample_frequency
+        if num_train_epochs is None:
+            num_train_epochs = math.ceil(
+                args.num_train_steps / len(pti_datamodule.train_dataset)
+            ) * args.gradient_accumulation_steps
+            pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps))
+        num_training_steps_per_epoch = math.ceil(
+            len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps)
+        num_train_steps = num_training_steps_per_epoch * num_train_epochs
+        if args.sample_num is not None:
+            pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
+
+        pti_project = "pti"
+
+        if accelerator.is_main_process:
+            accelerator.init_trackers(pti_project)
+
+        pti_sample_output_dir = output_dir / pti_project / "samples"
+
+        print("")
+        print(f"============ PTI ============")
+        print("")
+
+        pti_optimizer = create_optimizer([{
+            "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
+            "lr": args.learning_rate_emb,
+            "weight_decay": 0,
+        }])
+
+        pti_lr_scheduler = create_lr_scheduler(
+            "constant_with_warmup",
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            optimizer=pti_optimizer,
+            num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
+            train_epochs=num_train_epochs,
+            warmup_epochs=math.ceil(0.1 * num_train_epochs),
+        )
+
+        pti_checkpoint_output_dir = output_dir / pti_project / "model"
+
+        trainer(
+            strategy=lora_strategy,
+            train_dataloader=pti_datamodule.train_dataloader,
+            val_dataloader=pti_datamodule.val_dataloader,
+            optimizer=pti_optimizer,
+            lr_scheduler=pti_lr_scheduler,
+            num_train_epochs=num_train_epochs,
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            cycle=1,
+            pti_mode=True,
+            # --
+            group_labels=["emb"],
+            sample_output_dir=pti_sample_output_dir,
+            checkpoint_output_dir=pti_checkpoint_output_dir,
+            sample_frequency=pti_sample_frequency,
+        )
+
+        embeddings.persist()
+
     # LORA
     # --------------------------------------------------------------------------------
 
@@ -941,16 +1018,6 @@ def main():
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
-    group_labels = []
-    if len(args.placeholder_tokens) != 0:
-        group_labels.append("emb")
-    group_labels += ["unet", "text"]
-
-    training_iter = 0
-    learning_rate_emb = args.learning_rate_emb
-    learning_rate_unet = args.learning_rate_unet
-    learning_rate_text = args.learning_rate_text
-
     lora_project = "lora"
 
     if accelerator.is_main_process:
@@ -958,7 +1025,11 @@ def main():
 
     lora_sample_output_dir = output_dir / lora_project / "samples"
 
+    training_iter = 0
     auto_cycles = list(args.auto_cycles)
+    learning_rate_emb = args.learning_rate_emb
+    learning_rate_unet = args.learning_rate_unet
+    learning_rate_text = args.learning_rate_text
     lr_scheduler = args.lr_scheduler
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
@@ -969,6 +1040,18 @@ def main():
         else:
             response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
 
+        if response.lower().strip() == "o":
+            if args.learning_rate_emb is not None:
+                learning_rate_emb = args.learning_rate_emb * 2
+            if args.learning_rate_unet is not None:
+                learning_rate_unet = args.learning_rate_unet * 2
+            if args.learning_rate_text is not None:
+                learning_rate_text = args.learning_rate_text * 2
+        else:
+            learning_rate_emb = args.learning_rate_emb
+            learning_rate_unet = args.learning_rate_unet
+            learning_rate_text = args.learning_rate_text
+
         if response.lower().strip() == "o":
             lr_scheduler = "one_cycle"
             lr_warmup_epochs = args.lr_warmup_epochs
@@ -986,28 +1069,32 @@ def main():
             break
 
         print("")
-        print(f"============ LoRA cycle {training_iter + 1} ============")
+        print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
         print("")
 
         params_to_optimize = []
+        group_labels = []
+
+        params_to_optimize.append({
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_unet,
+        })
+        group_labels.append("unet")
+
+        if training_iter < args.train_text_encoder_cycles:
+            # if len(args.placeholder_tokens) != 0:
+            #     params_to_optimize.append({
+            #         "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
+            #         "lr": learning_rate_emb,
+            #         "weight_decay": 0,
+            #     })
+            #     group_labels.append("emb")
 
-        if len(args.placeholder_tokens) != 0:
             params_to_optimize.append({
-                "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
-                "lr": learning_rate_emb,
-                "weight_decay": 0,
-            })
-            group_labels.append("emb")
-        params_to_optimize += [
-            {
-                "params": (
-                    param
-                    for param in unet.parameters()
-                    if param.requires_grad
-                ),
-                "lr": learning_rate_unet,
-            },
-            {
                 "params": (
                     param
                     for param in itertools.chain(
@@ -1017,8 +1104,8 @@ def main():
                     if param.requires_grad
                 ),
                 "lr": learning_rate_text,
-            },
-        ]
+            })
+            group_labels.append("text")
 
         lora_optimizer = create_optimizer(params_to_optimize)
 
diff --git a/train_ti.py b/train_ti.py
index b00b0d7..84ca296 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -906,9 +906,6 @@ def main():
         if args.sample_num is not None:
             sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
-        training_iter = 0
-        learning_rate = args.learning_rate
-
         project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
 
         if accelerator.is_main_process:
@@ -916,7 +913,9 @@ def main():
 
         sample_output_dir = output_dir / project / "samples"
 
+        training_iter = 0
         auto_cycles = list(args.auto_cycles)
+        learning_rate = args.learning_rate
         lr_scheduler = args.lr_scheduler
         lr_warmup_epochs = args.lr_warmup_epochs
         lr_cycles = args.lr_cycles
@@ -928,6 +927,12 @@ def main():
                 response = input(
                     "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
 
+            if response.lower().strip() == "o":
+                if args.learning_rate is not None:
+                    learning_rate = args.learning_rate * 2
+            else:
+                learning_rate = args.learning_rate
+
             if response.lower().strip() == "o":
                 lr_scheduler = "one_cycle"
                 lr_warmup_epochs = args.lr_warmup_epochs
@@ -945,7 +950,7 @@ def main():
                 break
 
             print("")
-            print(f"------------ TI cycle {training_iter + 1} ------------")
+            print(f"------------ TI cycle {training_iter + 1}: {response} ------------")
             print("")
 
             optimizer = create_optimizer(
-- 
cgit v1.2.3-70-g09d2