diff options
| author | Volpeon <git@volpeon.ink> | 2022-10-13 09:45:27 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-10-13 09:45:27 +0200 |
| commit | db0996c299fdd559ebf9cd48f9dbe47474ed7b07 (patch) | |
| tree | 0d306c661ed5629e7d69566a82d588aca5ed86a9 /textual_inversion.py | |
| parent | Various updates (diff) | |
| download | textual-inversion-diff-db0996c299fdd559ebf9cd48f9dbe47474ed7b07.tar.gz textual-inversion-diff-db0996c299fdd559ebf9cd48f9dbe47474ed7b07.tar.bz2 textual-inversion-diff-db0996c299fdd559ebf9cd48f9dbe47474ed7b07.zip | |
Added TI+Dreambooth training
Diffstat (limited to 'textual_inversion.py')
| -rw-r--r-- | textual_inversion.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index 3a3741d..181a318 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -111,7 +111,7 @@ def parse_args(): | |||
| 111 | parser.add_argument( | 111 | parser.add_argument( |
| 112 | "--max_train_steps", | 112 | "--max_train_steps", |
| 113 | type=int, | 113 | type=int, |
| 114 | default=3000, | 114 | default=10000, |
| 115 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 115 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 116 | ) | 116 | ) |
| 117 | parser.add_argument( | 117 | parser.add_argument( |
| @@ -128,7 +128,7 @@ def parse_args(): | |||
| 128 | parser.add_argument( | 128 | parser.add_argument( |
| 129 | "--learning_rate", | 129 | "--learning_rate", |
| 130 | type=float, | 130 | type=float, |
| 131 | default=5e-5, | 131 | default=1e-4, |
| 132 | help="Initial learning rate (after the potential warmup period) to use.", | 132 | help="Initial learning rate (after the potential warmup period) to use.", |
| 133 | ) | 133 | ) |
| 134 | parser.add_argument( | 134 | parser.add_argument( |
| @@ -247,6 +247,11 @@ def parse_args(): | |||
| 247 | help="The weight of prior preservation loss." | 247 | help="The weight of prior preservation loss." |
| 248 | ) | 248 | ) |
| 249 | parser.add_argument( | 249 | parser.add_argument( |
| 250 | "--noise_timesteps", | ||
| 251 | type=int, | ||
| 252 | default=1000, | ||
| 253 | ) | ||
| 254 | parser.add_argument( | ||
| 250 | "--resume_from", | 255 | "--resume_from", |
| 251 | type=str, | 256 | type=str, |
| 252 | default=None, | 257 | default=None, |
| @@ -568,7 +573,7 @@ def main(): | |||
| 568 | beta_start=0.00085, | 573 | beta_start=0.00085, |
| 569 | beta_end=0.012, | 574 | beta_end=0.012, |
| 570 | beta_schedule="scaled_linear", | 575 | beta_schedule="scaled_linear", |
| 571 | num_train_timesteps=1000 | 576 | num_train_timesteps=args.noise_timesteps |
| 572 | ) | 577 | ) |
| 573 | 578 | ||
| 574 | def collate_fn(examples): | 579 | def collate_fn(examples): |
