From 515f0f1fdc9a76bf63bd746c291dcfec7fc747fb Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 13 Oct 2022 21:11:53 +0200 Subject: Added support for Aesthetic Gradients --- dreambooth_plus.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) (limited to 'dreambooth_plus.py') diff --git a/dreambooth_plus.py b/dreambooth_plus.py index 9e482b3..7996bc2 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py @@ -112,7 +112,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=3000, + default=1600, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -129,13 +129,13 @@ def parse_args(): parser.add_argument( "--learning_rate_unet", type=float, - default=1e-5, + default=5e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--learning_rate_text", type=float, - default=1e-4, + default=5e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -221,12 +221,6 @@ def parse_args(): "and an Nvidia Ampere GPU." ), ) - parser.add_argument( - "--local_rank", - type=int, - default=-1, - help="For distributed training: local_rank" - ) parser.add_argument( "--sample_frequency", type=int, @@ -293,10 +287,6 @@ def parse_args(): args = parser.parse_args( namespace=argparse.Namespace(**json.load(f)["args"])) - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - if args.train_data_file is None: raise ValueError("You must specify --train_data_file") -- cgit v1.2.3-54-g00ecf