summaryrefslogtreecommitdiffstats
path: root/dreambooth_plus.py
diff options
context:
space:
mode:
Diffstat (limited to 'dreambooth_plus.py')
-rw-r--r--dreambooth_plus.py16
1 files changed, 3 insertions, 13 deletions
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 9e482b3..7996bc2 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -112,7 +112,7 @@ def parse_args():
112 parser.add_argument( 112 parser.add_argument(
113 "--max_train_steps", 113 "--max_train_steps",
114 type=int, 114 type=int,
115 default=3000, 115 default=1600,
116 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 116 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
117 ) 117 )
118 parser.add_argument( 118 parser.add_argument(
@@ -129,13 +129,13 @@ def parse_args():
129 parser.add_argument( 129 parser.add_argument(
130 "--learning_rate_unet", 130 "--learning_rate_unet",
131 type=float, 131 type=float,
132 default=1e-5, 132 default=5e-6,
133 help="Initial learning rate (after the potential warmup period) to use.", 133 help="Initial learning rate (after the potential warmup period) to use.",
134 ) 134 )
135 parser.add_argument( 135 parser.add_argument(
136 "--learning_rate_text", 136 "--learning_rate_text",
137 type=float, 137 type=float,
138 default=1e-4, 138 default=5e-4,
139 help="Initial learning rate (after the potential warmup period) to use.", 139 help="Initial learning rate (after the potential warmup period) to use.",
140 ) 140 )
141 parser.add_argument( 141 parser.add_argument(
@@ -222,12 +222,6 @@ def parse_args():
222 ), 222 ),
223 ) 223 )
224 parser.add_argument( 224 parser.add_argument(
225 "--local_rank",
226 type=int,
227 default=-1,
228 help="For distributed training: local_rank"
229 )
230 parser.add_argument(
231 "--sample_frequency", 225 "--sample_frequency",
232 type=int, 226 type=int,
233 default=100, 227 default=100,
@@ -293,10 +287,6 @@ def parse_args():
293 args = parser.parse_args( 287 args = parser.parse_args(
294 namespace=argparse.Namespace(**json.load(f)["args"])) 288 namespace=argparse.Namespace(**json.load(f)["args"]))
295 289
296 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
297 if env_local_rank != -1 and env_local_rank != args.local_rank:
298 args.local_rank = env_local_rank
299
300 if args.train_data_file is None: 290 if args.train_data_file is None:
301 raise ValueError("You must specify --train_data_file") 291 raise ValueError("You must specify --train_data_file")
302 292