summaryrefslogtreecommitdiffstats
path: root/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--textual_inversion.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index 3a3741d..181a318 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -111,7 +111,7 @@ def parse_args():
111 parser.add_argument( 111 parser.add_argument(
112 "--max_train_steps", 112 "--max_train_steps",
113 type=int, 113 type=int,
114 default=3000, 114 default=10000,
115 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 115 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
116 ) 116 )
117 parser.add_argument( 117 parser.add_argument(
@@ -128,7 +128,7 @@ def parse_args():
128 parser.add_argument( 128 parser.add_argument(
129 "--learning_rate", 129 "--learning_rate",
130 type=float, 130 type=float,
131 default=5e-5, 131 default=1e-4,
132 help="Initial learning rate (after the potential warmup period) to use.", 132 help="Initial learning rate (after the potential warmup period) to use.",
133 ) 133 )
134 parser.add_argument( 134 parser.add_argument(
@@ -247,6 +247,11 @@ def parse_args():
247 help="The weight of prior preservation loss." 247 help="The weight of prior preservation loss."
248 ) 248 )
249 parser.add_argument( 249 parser.add_argument(
250 "--noise_timesteps",
251 type=int,
252 default=1000,
253 )
254 parser.add_argument(
250 "--resume_from", 255 "--resume_from",
251 type=str, 256 type=str,
252 default=None, 257 default=None,
@@ -568,7 +573,7 @@ def main():
568 beta_start=0.00085, 573 beta_start=0.00085,
569 beta_end=0.012, 574 beta_end=0.012,
570 beta_schedule="scaled_linear", 575 beta_schedule="scaled_linear",
571 num_train_timesteps=1000 576 num_train_timesteps=args.noise_timesteps
572 ) 577 )
573 578
574 def collate_fn(examples): 579 def collate_fn(examples):