From 2ddfbd65e482fa2361e8ba41b657656f825c9143 Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Tue, 18 Oct 2022 18:08:32 +0200
Subject: Adapt other scripts to the new prompt processing

---
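Note for reviewers: this patch only shows the call sites of the new
PromptProcessor (models/clip/prompt.py); the class itself was presumably
added in an earlier commit. A minimal sketch of the interface used here,
with method bodies reconstructed as assumptions from the inline code being
removed (not the actual implementation):

    class PromptProcessor:
        def __init__(self, tokenizer, text_encoder):
            self.tokenizer = tokenizer
            self.text_encoder = text_encoder

        def unify_input_ids(self, input_ids):
            # Pad a batch of variable-length token id lists into one "pt"
            # tensor; mirrors the tokenizer.pad(...) call removed from the
            # collate functions below.
            return self.tokenizer.pad(
                {"input_ids": input_ids}, padding=True, return_tensors="pt"
            ).input_ids

        def get_embeddings(self, input_ids):
            # Return the text encoder's last hidden state for conditioning;
            # mirrors the text_encoder(batch["input_ids"])[0] calls replaced
            # in the training loops below.
            return self.text_encoder(input_ids)[0]
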
 dreambooth.py        | 17 ++++++++++-------
 dreambooth_plus.py   |  6 +++---
 textual_inversion.py | 17 ++++++++++-------
 3 files changed, 23 insertions(+), 17 deletions(-)
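
The num_cycles default for "cosine_with_restarts" also changes: from half
the number of post-warmup epochs to its square root, which grows much more
slowly. A worked example with hypothetical values (only max_train_steps=2400
comes from this patch; the other two numbers are made up for illustration):

    import math

    max_train_steps = 2400            # new dreambooth_plus.py default
    lr_warmup_steps = 0               # assumed; not touched by this patch
    num_update_steps_per_epoch = 100  # hypothetical

    epochs = (max_train_steps - lr_warmup_steps) / num_update_steps_per_epoch  # 24.0
    old_cycles = math.ceil(epochs / 2)         # 12 restart cycles
    new_cycles = math.ceil(math.sqrt(epochs))  # 5 restart cycles

Fewer cycles over the same number of training steps means each cosine cycle
is longer, so the learning rate is reset less often on long runs.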

diff --git a/dreambooth.py b/dreambooth.py
index 770ad38..9786e0f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_a import EulerAScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
 
@@ -141,7 +142,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -494,6 +495,8 @@ def main():
         device=accelerator.device
     ) if args.use_ema else None
 
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
@@ -557,7 +560,7 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+        input_ids = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
@@ -570,7 +573,7 @@ def main():
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
+        prompt_processor=prompt_processor,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
@@ -641,8 +644,8 @@ def main():
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
+            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
@@ -756,7 +759,7 @@ def main():
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                     # Get the text embedding for conditioning
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -832,7 +835,7 @@ def main():
 
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index fa3a22b..06ff45b 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -125,7 +125,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1400,
+        default=2400,
         help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -752,8 +752,8 @@ def main():
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
+            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
diff --git a/textual_inversion.py b/textual_inversion.py
index 69d9c7f..8f266e0 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,6 +25,7 @@ from slugify import slugify
 from schedulers.scheduling_euler_a import EulerAScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
+from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
 
@@ -152,7 +153,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="cosine",
+        default="cosine_with_restarts",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -516,6 +517,8 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='unet')
 
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
@@ -594,7 +597,7 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
 
-        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+        input_ids = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
             "prompts": prompts,
@@ -607,7 +610,7 @@ def main():
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
-        tokenizer=tokenizer,
+        prompt_processor=prompt_processor,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
@@ -678,8 +681,8 @@ def main():
             optimizer=optimizer,
             num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch) / 2),
+            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
+                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
@@ -794,7 +797,7 @@ def main():
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                     # Get the text embedding for conditioning
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -885,7 +888,7 @@ def main():
 
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
 
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-- 
cgit v1.2.3-70-g09d2