From 179a45253a5b3712f32bd127f693a6bb810a9c17 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Tue, 28 Mar 2023 16:24:22 +0200
Subject: Support num_train_steps arg again

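Replace the --max_train_steps argument of train_dreambooth.py and
train_lora.py with --num_train_steps and add the same argument to
train_ti.py. When --num_train_epochs is not given, the epoch count is
now derived from the requested number of steps and the length of the
training set. keyword_filter now matches against a list of collections,
and train_ti.py only expands alias tokens in --sequential mode and
persists embeddings loaded from --embeddings_dir.

A rough sketch of the derived epoch count (the dataset length stands in
for the number of batches per epoch; the numbers are made up):

    # e.g. 2000 requested steps, 640 batches per epoch -> ceil(2000 / 640) = 4 epochs
    num_train_epochs = args.num_train_epochs
    if num_train_epochs is None:
        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))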
---
 data/csv.py         |  8 ++++++--
 train_dreambooth.py | 17 +++++++++++------
 train_lora.py       | 17 +++++++++++------
 train_ti.py         | 30 +++++++++++++++++++++---------
 4 files changed, 49 insertions(+), 23 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 9770bec..c00ea07 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -143,7 +143,7 @@ class VlpnDataItem(NamedTuple):
 
 def keyword_filter(
     placeholder_tokens: Optional[list[str]],
-    collection: Optional[list[str]],
+    collections: Optional[list[str]],
     exclude_collections: Optional[list[str]],
     item: VlpnDataItem
 ):
@@ -152,11 +152,15 @@ def keyword_filter(
         for keyword in placeholder_tokens
         for part in item.prompt
     )
-    cond2 = collection is None or collection in item.collection
+    cond2 = collections is None or any(
+        collection in item.collection
+        for collection in collections
+    )
     cond3 = exclude_collections is None or not any(
         collection in item.collection
         for collection in exclude_collections
     )
+
     return cond1 and cond2 and cond3
 
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9345797..acb8287 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -189,13 +190,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -595,6 +595,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+
     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
         params_to_optimize += (
@@ -619,7 +624,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -631,7 +636,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
diff --git a/train_lora.py b/train_lora.py
index 7ecddf0..a9c6e52 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -178,13 +179,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -627,6 +627,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+
     optimizer = create_optimizer(
         itertools.chain(
             unet.parameters(),
@@ -647,7 +652,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -659,7 +664,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
diff --git a/train_ti.py b/train_ti.py
index e4fd464..7bcc72f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,6 +3,7 @@ import datetime
 import logging
 from functools import partial
 from pathlib import Path
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -207,7 +208,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
+    )
+    parser.add_argument(
+        "--num_train_steps",
+        type=int,
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -513,13 +519,13 @@ def parse_args():
     if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
         raise ValueError("--alias_tokens must be a list with an even number of items")
 
-    args.alias_tokens += [
-        item
-        for pair in zip(args.placeholder_tokens, args.initializer_tokens)
-        for item in pair
-    ]
-
     if args.sequential:
+        args.alias_tokens += [
+            item
+            for pair in zip(args.placeholder_tokens, args.initializer_tokens)
+            for item in pair
+        ]
+
         if isinstance(args.train_data_template, str):
             args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
@@ -607,6 +613,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
@@ -682,7 +689,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
-        num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         checkpoint_frequency=args.checkpoint_frequency,
@@ -752,6 +758,11 @@ def main():
         )
         datamodule.setup()
 
+        num_train_epochs = args.num_train_epochs
+
+        if num_train_epochs is None:
+            num_train_epochs = math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
             lr=args.learning_rate,
@@ -769,7 +780,7 @@ def main():
             annealing_exp=args.lr_annealing_exp,
             cycles=args.lr_cycles,
             end_lr=1e3,
-            train_epochs=args.num_train_epochs,
+            train_epochs=num_train_epochs,
             warmup_epochs=args.lr_warmup_epochs,
         )
 
@@ -779,6 +790,7 @@ def main():
             val_dataloader=datamodule.val_dataloader,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
+            num_train_epochs=num_train_epochs,
             # --
             sample_output_dir=sample_output_dir,
             placeholder_tokens=placeholder_tokens,
-- 
cgit v1.2.3-70-g09d2