summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py9
1 files changed, 1 insertions, 8 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 659b84c..0543a35 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -246,12 +246,6 @@ def parse_args():
246 ), 246 ),
247 ) 247 )
248 parser.add_argument( 248 parser.add_argument(
249 "--offset_noise_strength",
250 type=float,
251 default=0,
252 help="Perlin offset noise strength.",
253 )
254 parser.add_argument(
255 "--input_pertubation", 249 "--input_pertubation",
256 type=float, 250 type=float,
257 default=0, 251 default=0,
@@ -496,7 +490,6 @@ def parse_args():
496 default=1.0, 490 default=1.0,
497 help="The weight of prior preservation loss.", 491 help="The weight of prior preservation loss.",
498 ) 492 )
499 parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
500 parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") 493 parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
501 parser.add_argument( 494 parser.add_argument(
502 "--emb_dropout", 495 "--emb_dropout",
@@ -679,6 +672,7 @@ def main():
679 672
680 if args.gradient_checkpointing: 673 if args.gradient_checkpointing:
681 unet.enable_gradient_checkpointing() 674 unet.enable_gradient_checkpointing()
675 text_encoder.gradient_checkpointing_enable()
682 676
683 if len(args.alias_tokens) != 0: 677 if len(args.alias_tokens) != 0:
684 alias_placeholder_tokens = args.alias_tokens[::2] 678 alias_placeholder_tokens = args.alias_tokens[::2]
@@ -1074,7 +1068,6 @@ def main():
1074 sample_output_dir=dreambooth_sample_output_dir, 1068 sample_output_dir=dreambooth_sample_output_dir,
1075 checkpoint_output_dir=dreambooth_checkpoint_output_dir, 1069 checkpoint_output_dir=dreambooth_checkpoint_output_dir,
1076 sample_frequency=dreambooth_sample_frequency, 1070 sample_frequency=dreambooth_sample_frequency,
1077 offset_noise_strength=args.offset_noise_strength,
1078 input_pertubation=args.input_pertubation, 1071 input_pertubation=args.input_pertubation,
1079 no_val=args.valid_set_size == 0, 1072 no_val=args.valid_set_size == 0,
1080 avg_loss=avg_loss, 1073 avg_loss=avg_loss,