diff options
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r-- | dreambooth_plus.py | 16 |
1 files changed, 3 insertions, 13 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 9e482b3..7996bc2 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py | |||
@@ -112,7 +112,7 @@ def parse_args(): | |||
112 | parser.add_argument( | 112 | parser.add_argument( |
113 | "--max_train_steps", | 113 | "--max_train_steps", |
114 | type=int, | 114 | type=int, |
115 | default=3000, | 115 | default=1600, |
116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 116 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
117 | ) | 117 | ) |
118 | parser.add_argument( | 118 | parser.add_argument( |
@@ -129,13 +129,13 @@ def parse_args(): | |||
129 | parser.add_argument( | 129 | parser.add_argument( |
130 | "--learning_rate_unet", | 130 | "--learning_rate_unet", |
131 | type=float, | 131 | type=float, |
132 | default=1e-5, | 132 | default=5e-6, |
133 | help="Initial learning rate (after the potential warmup period) to use.", | 133 | help="Initial learning rate (after the potential warmup period) to use.", |
134 | ) | 134 | ) |
135 | parser.add_argument( | 135 | parser.add_argument( |
136 | "--learning_rate_text", | 136 | "--learning_rate_text", |
137 | type=float, | 137 | type=float, |
138 | default=1e-4, | 138 | default=5e-4, |
139 | help="Initial learning rate (after the potential warmup period) to use.", | 139 | help="Initial learning rate (after the potential warmup period) to use.", |
140 | ) | 140 | ) |
141 | parser.add_argument( | 141 | parser.add_argument( |
@@ -222,12 +222,6 @@ def parse_args(): | |||
222 | ), | 222 | ), |
223 | ) | 223 | ) |
224 | parser.add_argument( | 224 | parser.add_argument( |
225 | "--local_rank", | ||
226 | type=int, | ||
227 | default=-1, | ||
228 | help="For distributed training: local_rank" | ||
229 | ) | ||
230 | parser.add_argument( | ||
231 | "--sample_frequency", | 225 | "--sample_frequency", |
232 | type=int, | 226 | type=int, |
233 | default=100, | 227 | default=100, |
@@ -293,10 +287,6 @@ def parse_args(): | |||
293 | args = parser.parse_args( | 287 | args = parser.parse_args( |
294 | namespace=argparse.Namespace(**json.load(f)["args"])) | 288 | namespace=argparse.Namespace(**json.load(f)["args"])) |
295 | 289 | ||
296 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) | ||
297 | if env_local_rank != -1 and env_local_rank != args.local_rank: | ||
298 | args.local_rank = env_local_rank | ||
299 | |||
300 | if args.train_data_file is None: | 290 | if args.train_data_file is None: |
301 | raise ValueError("You must specify --train_data_file") | 291 | raise ValueError("You must specify --train_data_file") |
302 | 292 | ||