diff options
Diffstat (limited to 'dreambooth_plus.py')
| -rw-r--r-- | dreambooth_plus.py | 16 |
1 files changed, 3 insertions, 13 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 9e482b3..7996bc2 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
| @@ -112,7 +112,7 @@ def parse_args(): | |||
| 112 | parser.add_argument( | 112 | parser.add_argument( |
| 113 | "--max_train_steps", | 113 | "--max_train_steps", |
| 114 | type=int, | 114 | type=int, |
| 115 | default=3000, | 115 | default=1600, |
| 116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
| 117 | ) | 117 | ) |
| 118 | parser.add_argument( | 118 | parser.add_argument( |
| @@ -129,13 +129,13 @@ def parse_args(): | |||
| 129 | parser.add_argument( | 129 | parser.add_argument( |
| 130 | "--learning_rate_unet", | 130 | "--learning_rate_unet", |
| 131 | type=float, | 131 | type=float, |
| 132 | default=1e-5, | 132 | default=5e-6, |
| 133 | help="Initial learning rate (after the potential warmup period) to use.", | 133 | help="Initial learning rate (after the potential warmup period) to use.", |
| 134 | ) | 134 | ) |
| 135 | parser.add_argument( | 135 | parser.add_argument( |
| 136 | "--learning_rate_text", | 136 | "--learning_rate_text", |
| 137 | type=float, | 137 | type=float, |
| 138 | default=1e-4, | 138 | default=5e-4, |
| 139 | help="Initial learning rate (after the potential warmup period) to use.", | 139 | help="Initial learning rate (after the potential warmup period) to use.", |
| 140 | ) | 140 | ) |
| 141 | parser.add_argument( | 141 | parser.add_argument( |
| @@ -222,12 +222,6 @@ def parse_args(): | |||
| 222 | ), | 222 | ), |
| 223 | ) | 223 | ) |
| 224 | parser.add_argument( | 224 | parser.add_argument( |
| 225 | "--local_rank", | ||
| 226 | type=int, | ||
| 227 | default=-1, | ||
| 228 | help="For distributed training: local_rank" | ||
| 229 | ) | ||
| 230 | parser.add_argument( | ||
| 231 | "--sample_frequency", | 225 | "--sample_frequency", |
| 232 | type=int, | 226 | type=int, |
| 233 | default=100, | 227 | default=100, |
| @@ -293,10 +287,6 @@ def parse_args(): | |||
| 293 | args = parser.parse_args( | 287 | args = parser.parse_args( |
| 294 | namespace=argparse.Namespace(**json.load(f)["args"])) | 288 | namespace=argparse.Namespace(**json.load(f)["args"])) |
| 295 | 289 | ||
| 296 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) | ||
| 297 | if env_local_rank != -1 and env_local_rank != args.local_rank: | ||
| 298 | args.local_rank = env_local_rank | ||
| 299 | |||
| 300 | if args.train_data_file is None: | 290 | if args.train_data_file is None: |
| 301 | raise ValueError("You must specify --train_data_file") | 291 | raise ValueError("You must specify --train_data_file") |
| 302 | 292 | ||
