diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 9 |
1 files changed, 1 insertions, 8 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 659b84c..0543a35 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -246,12 +246,6 @@ def parse_args(): | |||
246 | ), | 246 | ), |
247 | ) | 247 | ) |
248 | parser.add_argument( | 248 | parser.add_argument( |
249 | "--offset_noise_strength", | ||
250 | type=float, | ||
251 | default=0, | ||
252 | help="Perlin offset noise strength.", | ||
253 | ) | ||
254 | parser.add_argument( | ||
255 | "--input_pertubation", | 249 | "--input_pertubation", |
256 | type=float, | 250 | type=float, |
257 | default=0, | 251 | default=0, |
@@ -496,7 +490,6 @@ def parse_args(): | |||
496 | default=1.0, | 490 | default=1.0, |
497 | help="The weight of prior preservation loss.", | 491 | help="The weight of prior preservation loss.", |
498 | ) | 492 | ) |
499 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
500 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | 493 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") |
501 | parser.add_argument( | 494 | parser.add_argument( |
502 | "--emb_dropout", | 495 | "--emb_dropout", |
@@ -679,6 +672,7 @@ def main(): | |||
679 | 672 | ||
680 | if args.gradient_checkpointing: | 673 | if args.gradient_checkpointing: |
681 | unet.enable_gradient_checkpointing() | 674 | unet.enable_gradient_checkpointing() |
675 | text_encoder.gradient_checkpointing_enable() | ||
682 | 676 | ||
683 | if len(args.alias_tokens) != 0: | 677 | if len(args.alias_tokens) != 0: |
684 | alias_placeholder_tokens = args.alias_tokens[::2] | 678 | alias_placeholder_tokens = args.alias_tokens[::2] |
@@ -1074,7 +1068,6 @@ def main(): | |||
1074 | sample_output_dir=dreambooth_sample_output_dir, | 1068 | sample_output_dir=dreambooth_sample_output_dir, |
1075 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | 1069 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, |
1076 | sample_frequency=dreambooth_sample_frequency, | 1070 | sample_frequency=dreambooth_sample_frequency, |
1077 | offset_noise_strength=args.offset_noise_strength, | ||
1078 | input_pertubation=args.input_pertubation, | 1071 | input_pertubation=args.input_pertubation, |
1079 | no_val=args.valid_set_size == 0, | 1072 | no_val=args.valid_set_size == 0, |
1080 | avg_loss=avg_loss, | 1073 | avg_loss=avg_loss, |