diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 10 |
1 files changed, 0 insertions, 10 deletions
diff --git a/dreambooth.py b/dreambooth.py index 699313e..072142e 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -216,12 +216,6 @@ def parse_args(): | |||
| 216 | ), | 216 | ), |
| 217 | ) | 217 | ) |
| 218 | parser.add_argument( | 218 | parser.add_argument( |
| 219 | "--local_rank", | ||
| 220 | type=int, | ||
| 221 | default=-1, | ||
| 222 | help="For distributed training: local_rank" | ||
| 223 | ) | ||
| 224 | parser.add_argument( | ||
| 225 | "--sample_frequency", | 219 | "--sample_frequency", |
| 226 | type=int, | 220 | type=int, |
| 227 | default=100, | 221 | default=100, |
| @@ -287,10 +281,6 @@ def parse_args(): | |||
| 287 | args = parser.parse_args( | 281 | args = parser.parse_args( |
| 288 | namespace=argparse.Namespace(**json.load(f)["args"])) | 282 | namespace=argparse.Namespace(**json.load(f)["args"])) |
| 289 | 283 | ||
| 290 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) | ||
| 291 | if env_local_rank != -1 and env_local_rank != args.local_rank: | ||
| 292 | args.local_rank = env_local_rank | ||
| 293 | |||
| 294 | if args.train_data_file is None: | 284 | if args.train_data_file is None: |
| 295 | raise ValueError("You must specify --train_data_file") | 285 | raise ValueError("You must specify --train_data_file") |
| 296 | 286 | ||
