diff options
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 9 |
1 files changed, 1 insertions, 8 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 659b84c..0543a35 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -246,12 +246,6 @@ def parse_args(): | |||
| 246 | ), | 246 | ), |
| 247 | ) | 247 | ) |
| 248 | parser.add_argument( | 248 | parser.add_argument( |
| 249 | "--offset_noise_strength", | ||
| 250 | type=float, | ||
| 251 | default=0, | ||
| 252 | help="Perlin offset noise strength.", | ||
| 253 | ) | ||
| 254 | parser.add_argument( | ||
| 255 | "--input_pertubation", | 249 | "--input_pertubation", |
| 256 | type=float, | 250 | type=float, |
| 257 | default=0, | 251 | default=0, |
| @@ -496,7 +490,6 @@ def parse_args(): | |||
| 496 | default=1.0, | 490 | default=1.0, |
| 497 | help="The weight of prior preservation loss.", | 491 | help="The weight of prior preservation loss.", |
| 498 | ) | 492 | ) |
| 499 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
| 500 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | 493 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") |
| 501 | parser.add_argument( | 494 | parser.add_argument( |
| 502 | "--emb_dropout", | 495 | "--emb_dropout", |
| @@ -679,6 +672,7 @@ def main(): | |||
| 679 | 672 | ||
| 680 | if args.gradient_checkpointing: | 673 | if args.gradient_checkpointing: |
| 681 | unet.enable_gradient_checkpointing() | 674 | unet.enable_gradient_checkpointing() |
| 675 | text_encoder.gradient_checkpointing_enable() | ||
| 682 | 676 | ||
| 683 | if len(args.alias_tokens) != 0: | 677 | if len(args.alias_tokens) != 0: |
| 684 | alias_placeholder_tokens = args.alias_tokens[::2] | 678 | alias_placeholder_tokens = args.alias_tokens[::2] |
| @@ -1074,7 +1068,6 @@ def main(): | |||
| 1074 | sample_output_dir=dreambooth_sample_output_dir, | 1068 | sample_output_dir=dreambooth_sample_output_dir, |
| 1075 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | 1069 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, |
| 1076 | sample_frequency=dreambooth_sample_frequency, | 1070 | sample_frequency=dreambooth_sample_frequency, |
| 1077 | offset_noise_strength=args.offset_noise_strength, | ||
| 1078 | input_pertubation=args.input_pertubation, | 1071 | input_pertubation=args.input_pertubation, |
| 1079 | no_val=args.valid_set_size == 0, | 1072 | no_val=args.valid_set_size == 0, |
| 1080 | avg_loss=avg_loss, | 1073 | avg_loss=avg_loss, |
