From db0996c299fdd559ebf9cd48f9dbe47474ed7b07 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 13 Oct 2022 09:45:27 +0200 Subject: Added TI+Dreambooth training --- textual_inversion.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 3a3741d..181a318 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -111,7 +111,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=3000, + default=10000, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -128,7 +128,7 @@ def parse_args(): parser.add_argument( "--learning_rate", type=float, - default=5e-5, + default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -246,6 +246,11 @@ def parse_args(): default=1.0, help="The weight of prior preservation loss." ) + parser.add_argument( + "--noise_timesteps", + type=int, + default=1000, + ) parser.add_argument( "--resume_from", type=str, @@ -568,7 +573,7 @@ def main(): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000 + num_train_timesteps=args.noise_timesteps ) def collate_fn(examples): -- cgit v1.2.3-54-g00ecf