diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 3a3741d..181a318 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -111,7 +111,7 @@ def parse_args(): | |||
111 | parser.add_argument( | 111 | parser.add_argument( |
112 | "--max_train_steps", | 112 | "--max_train_steps", |
113 | type=int, | 113 | type=int, |
114 | default=3000, | 114 | default=10000, |
115 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 115 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
116 | ) | 116 | ) |
117 | parser.add_argument( | 117 | parser.add_argument( |
@@ -128,7 +128,7 @@ def parse_args(): | |||
128 | parser.add_argument( | 128 | parser.add_argument( |
129 | "--learning_rate", | 129 | "--learning_rate", |
130 | type=float, | 130 | type=float, |
131 | default=5e-5, | 131 | default=1e-4, |
132 | help="Initial learning rate (after the potential warmup period) to use.", | 132 | help="Initial learning rate (after the potential warmup period) to use.", |
133 | ) | 133 | ) |
134 | parser.add_argument( | 134 | parser.add_argument( |
@@ -247,6 +247,11 @@ def parse_args(): | |||
247 | help="The weight of prior preservation loss." | 247 | help="The weight of prior preservation loss." |
248 | ) | 248 | ) |
249 | parser.add_argument( | 249 | parser.add_argument( |
250 | "--noise_timesteps", | ||
251 | type=int, | ||
252 | default=1000, | ||
253 | ) | ||
254 | parser.add_argument( | ||
250 | "--resume_from", | 255 | "--resume_from", |
251 | type=str, | 256 | type=str, |
252 | default=None, | 257 | default=None, |
@@ -568,7 +573,7 @@ def main(): | |||
568 | beta_start=0.00085, | 573 | beta_start=0.00085, |
569 | beta_end=0.012, | 574 | beta_end=0.012, |
570 | beta_schedule="scaled_linear", | 575 | beta_schedule="scaled_linear", |
571 | num_train_timesteps=1000 | 576 | num_train_timesteps=args.noise_timesteps |
572 | ) | 577 | ) |
573 | 578 | ||
574 | def collate_fn(examples): | 579 | def collate_fn(examples): |