From 36440e48ce279872d6e736bcb1bf57d13da73a11 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Mon, 16 Jan 2023 17:09:01 +0100
Subject: Moved multi-TI code from Dreambooth to TI script

---
 data/csv.py            |   2 +-
 train_dreambooth.py    | 135 +-----------------------------
 train_ti.py            | 221 +++++++++++++++++++++++++------------------------
 training/functional.py |  17 +++-
 4 files changed, 131 insertions(+), 244 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 85b98f8..6857b6f 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -282,7 +282,7 @@ class VlpnDataModule():
         collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
 
         if valid_set_size == 0:
-            data_train, data_val = items, items[:1]
+            data_train, data_val = items, []
         else:
             data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 1dc41b1..6511f9b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -199,23 +199,6 @@ def parse_args():
         type=int,
         default=100
     )
-    parser.add_argument(
-        "--ti_data_template",
-        type=str,
-        nargs='*',
-        default=[],
-    )
-    parser.add_argument(
-        "--ti_num_train_epochs",
-        type=int,
-        default=10
-    )
-    parser.add_argument(
-        "--ti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -244,12 +227,6 @@ def parse_args():
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--ti_learning_rate",
-        type=float,
-        default=1e-2,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -482,12 +459,6 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if isinstance(args.ti_data_template, str):
-        args.ti_data_template = [args.ti_data_template]
-
-    if len(args.ti_data_template) == 0:
-        raise ValueError("You must specify --ti_data_template")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -521,8 +492,6 @@ def main():
 
     set_seed(args.seed)
 
-    seed_generator = torch.Generator().manual_seed(args.seed)
-
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -583,107 +552,6 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
-    # Initial TI
-
-    print("Phase 1: Textual Inversion")
-
-    cur_dir = output_dir.joinpath("1-ti")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.ti_data_template
-    ):
-        cur_subdir = cur_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
-
-        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-            tokenizer=tokenizer,
-            embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
-        )
-
-        print(
-            f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
-
-        args.seed = seed_generator.seed()
-
-        datamodule = VlpnDataModule(
-            data_file=args.train_data_file,
-            batch_size=args.ti_batch_size,
-            tokenizer=tokenizer,
-            class_subdir=args.class_image_dir,
-            num_class_images=args.num_class_images,
-            size=args.resolution,
-            shuffle=not args.no_tag_shuffle,
-            template_key=data_template,
-            valid_set_size=1,
-            train_set_pad=args.train_set_pad,
-            valid_set_pad=args.valid_set_pad,
-            seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
-            dtype=weight_dtype
-        )
-        datamodule.setup()
-
-        optimizer = optimizer_class(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            lr=args.ti_learning_rate,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=0.0,
-            eps=args.adam_epsilon,
-        )
-
-        lr_scheduler = get_scheduler(
-            "one_cycle",
-            optimizer=optimizer,
-            num_training_steps_per_epoch=len(datamodule.train_dataloader),
-            gradient_accumulation_steps=args.gradient_accumulation_steps,
-            train_epochs=args.ti_num_train_epochs,
-        )
-
-        trainer(
-            callbacks_fn=textual_inversion_strategy,
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            seed=args.seed,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=args.ti_num_train_epochs,
-            sample_frequency=args.ti_num_train_epochs // 5,
-            checkpoint_frequency=9999999,
-            # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
-            placeholder_tokens=[placeholder_token],
-            placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.ti_learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=True,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
-        )
-
-        embeddings.persist()
-
-    # Dreambooth
-
-    print("Phase 2: Dreambooth")
-
-    cur_dir = output_dir.joinpath("2-db")
-    cur_dir.mkdir(parents=True, exist_ok=True)
-
-    args.seed = seed_generator.seed()
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -746,12 +614,13 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
+        prepare_unet=True,
         num_train_epochs=args.num_train_epochs,
         sample_frequency=args.sample_frequency,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=cur_dir,
+        output_dir=output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
diff --git a/train_ti.py b/train_ti.py
index 7aecdef..adba8d4 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -51,6 +51,7 @@ def parse_args():
     parser.add_argument(
         "--train_data_template",
         type=str,
+        nargs='*',
         default="template",
     )
     parser.add_argument(
@@ -468,11 +469,17 @@ def parse_args():
         args.num_vectors = 1
 
     if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
+        args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
+    if isinstance(args.train_data_template, str):
+        args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
+
+    if len(args.placeholder_tokens) != len(args.train_data_template):
+        raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+
     if isinstance(args.collection, str):
         args.collection = [args.collection]
 
@@ -507,6 +514,8 @@ def main():
 
     set_seed(args.seed)
 
+    seed_generator = torch.Generator().manual_seed(args.seed)
+
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -531,19 +540,6 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
-        tokenizer=tokenizer,
-        embeddings=embeddings,
-        placeholder_tokens=args.placeholder_tokens,
-        initializer_tokens=args.initializer_tokens,
-        num_vectors=args.num_vectors
-    )
-
-    if len(placeholder_token_ids) != 0:
-        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
-        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
-        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -566,43 +562,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    datamodule = VlpnDataModule(
-        data_file=args.train_data_file,
-        batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=args.class_image_dir,
-        num_class_images=args.num_class_images,
-        size=args.resolution,
-        num_buckets=args.num_buckets,
-        progressive_buckets=args.progressive_buckets,
-        bucket_step_size=args.bucket_step_size,
-        bucket_max_pixels=args.bucket_max_pixels,
-        dropout=args.tag_dropout,
-        shuffle=not args.no_tag_shuffle,
-        template_key=args.train_data_template,
-        valid_set_size=args.valid_set_size,
-        train_set_pad=args.train_set_pad,
-        valid_set_pad=args.valid_set_pad,
-        seed=args.seed,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
-        dtype=weight_dtype
-    )
-    datamodule.setup()
-
-    if args.num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            sample_scheduler,
-            datamodule.train_dataset,
-            args.sample_batch_size,
-            args.sample_image_size,
-            args.sample_steps
-        )
-
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -615,63 +574,111 @@ def main():
         callbacks_fn=textual_inversion_strategy
     )
 
-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
+    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
+        range(len(args.placeholder_tokens)),
+        args.placeholder_tokens,
+        args.initializer_tokens,
+        args.num_vectors,
+        args.train_data_template
+    ):
+        cur_subdir = output_dir.joinpath(placeholder_token)
+        cur_subdir.mkdir(parents=True, exist_ok=True)
+
+        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
+            tokenizer=tokenizer,
+            embeddings=embeddings,
+            placeholder_tokens=[placeholder_token],
+            initializer_tokens=[initializer_token],
+            num_vectors=[num_vectors]
+        )
 
-    lr_scheduler = get_scheduler(
-        args.lr_scheduler,
-        optimizer=optimizer,
-        num_training_steps_per_epoch=len(datamodule.train_dataloader),
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        min_lr=args.lr_min_lr,
-        warmup_func=args.lr_warmup_func,
-        annealing_func=args.lr_annealing_func,
-        warmup_exp=args.lr_warmup_exp,
-        annealing_exp=args.lr_annealing_exp,
-        cycles=args.lr_cycles,
-        train_epochs=args.num_train_epochs,
-        warmup_epochs=args.lr_warmup_epochs,
-    )
-
-    trainer(
-        project="textual_inversion",
-        train_dataloader=datamodule.train_dataloader,
-        val_dataloader=datamodule.val_dataloader,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
-        sample_frequency=args.sample_frequency,
-        checkpoint_frequency=args.checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        with_prior_preservation=args.num_class_images != 0,
-        prior_loss_weight=args.prior_loss_weight,
-        # --
-        tokenizer=tokenizer,
-        sample_scheduler=sample_scheduler,
-        output_dir=output_dir,
-        placeholder_tokens=args.placeholder_tokens,
-        placeholder_token_ids=placeholder_token_ids,
-        learning_rate=args.learning_rate,
-        gradient_checkpointing=args.gradient_checkpointing,
-        use_emb_decay=args.use_emb_decay,
-        emb_decay_target=args.emb_decay_target,
-        emb_decay_factor=args.emb_decay_factor,
-        emb_decay_start=args.emb_decay_start,
-        use_ema=args.use_ema,
-        ema_inv_gamma=args.ema_inv_gamma,
-        ema_power=args.ema_power,
-        ema_max_decay=args.ema_max_decay,
-        sample_batch_size=args.sample_batch_size,
-        sample_num_batches=args.sample_batches,
-        sample_num_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
-    )
+        print(
+            f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
+
+        args.seed = seed_generator.seed()
+
+        datamodule = VlpnDataModule(
+            data_file=args.train_data_file,
+            batch_size=args.train_batch_size,
+            tokenizer=tokenizer,
+            class_subdir=args.class_image_dir,
+            num_class_images=args.num_class_images,
+            size=args.resolution,
+            num_buckets=args.num_buckets,
+            progressive_buckets=args.progressive_buckets,
+            bucket_step_size=args.bucket_step_size,
+            bucket_max_pixels=args.bucket_max_pixels,
+            dropout=args.tag_dropout,
+            shuffle=not args.no_tag_shuffle,
+            template_key=data_template,
+            valid_set_size=args.valid_set_size,
+            train_set_pad=args.train_set_pad,
+            valid_set_pad=args.valid_set_pad,
+            seed=args.seed,
+            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
+            dtype=weight_dtype
+        )
+        datamodule.setup()
+
+        optimizer = optimizer_class(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            amsgrad=args.adam_amsgrad,
+        )
+
+        lr_scheduler = get_scheduler(
+            args.lr_scheduler,
+            optimizer=optimizer,
+            num_training_steps_per_epoch=len(datamodule.train_dataloader),
+            gradient_accumulation_steps=args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            train_epochs=args.num_train_epochs,
+            warmup_epochs=args.lr_warmup_epochs,
+        )
+
+        trainer(
+            project="textual_inversion",
+            train_dataloader=datamodule.train_dataloader,
+            val_dataloader=datamodule.val_dataloader,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            num_train_epochs=args.num_train_epochs,
+            sample_frequency=args.sample_frequency,
+            checkpoint_frequency=args.checkpoint_frequency,
+            global_step_offset=global_step_offset,
+            with_prior_preservation=args.num_class_images != 0,
+            prior_loss_weight=args.prior_loss_weight,
+            # --
+            tokenizer=tokenizer,
+            sample_scheduler=sample_scheduler,
+            output_dir=cur_subdir,
+            placeholder_tokens=[placeholder_token],
+            placeholder_token_ids=placeholder_token_ids,
+            learning_rate=args.learning_rate,
+            gradient_checkpointing=args.gradient_checkpointing,
+            use_emb_decay=args.use_emb_decay,
+            emb_decay_target=args.emb_decay_target,
+            emb_decay_factor=args.emb_decay_factor,
+            emb_decay_start=args.emb_decay_start,
+            use_ema=args.use_ema,
+            ema_inv_gamma=args.ema_inv_gamma,
+            ema_power=args.ema_power,
+            ema_max_decay=args.ema_max_decay,
+            sample_batch_size=args.sample_batch_size,
+            sample_num_batches=args.sample_batches,
+            sample_num_steps=args.sample_steps,
+            sample_image_size=args.sample_image_size,
+        )
+
+        embeddings.persist()
 
 
 if __name__ == "__main__":
diff --git a/training/functional.py b/training/functional.py
index b6b5d87..1548784 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -515,6 +515,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
+    prepare_unet: bool = False,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -523,9 +524,19 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
+    prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
+
+    if prepare_unet:
+        prep.append(unet)
+
+    prep = accelerator.prepare(*prep)
+
+    if prepare_unet:
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
+    else:
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
+
+        unet.to(accelerator.device, dtype=dtype)
 
     vae.to(accelerator.device, dtype=dtype)
 
-- 
cgit v1.2.3-70-g09d2