diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 30 |
1 file changed, 7 insertions, 23 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index c8f03ea..be4da1a 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger | |||
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | 18 | ||
19 | # from diffusers.models.attention_processor import AttnProcessor | 19 | # from diffusers.models.attention_processor import AttnProcessor |
20 | from diffusers.utils.import_utils import is_xformers_available | ||
21 | import transformers | 20 | import transformers |
22 | 21 | ||
23 | import numpy as np | 22 | import numpy as np |
@@ -48,25 +47,6 @@ hidet.torch.dynamo_config.use_tensor_core(True) | |||
48 | hidet.torch.dynamo_config.search_space(0) | 47 | hidet.torch.dynamo_config.search_space(0) |
49 | 48 | ||
50 | 49 | ||
51 | def patch_xformers(dtype): | ||
52 | if is_xformers_available(): | ||
53 | import xformers | ||
54 | import xformers.ops | ||
55 | |||
56 | orig_xformers_memory_efficient_attention = ( | ||
57 | xformers.ops.memory_efficient_attention | ||
58 | ) | ||
59 | |||
60 | def xformers_memory_efficient_attention( | ||
61 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
62 | ): | ||
63 | return orig_xformers_memory_efficient_attention( | ||
64 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
65 | ) | ||
66 | |||
67 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
68 | |||
69 | |||
70 | def parse_args(): | 50 | def parse_args(): |
71 | parser = argparse.ArgumentParser(description="Simple example of a training script.") | 51 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
72 | parser.add_argument( | 52 | parser.add_argument( |
@@ -224,6 +204,12 @@ def parse_args(): | |||
224 | help="A collection to filter the dataset.", | 204 | help="A collection to filter the dataset.", |
225 | ) | 205 | ) |
226 | parser.add_argument( | 206 | parser.add_argument( |
207 | "--validation_prompts", | ||
208 | type=str, | ||
209 | nargs="*", | ||
210 | help="Prompts for additional validation images", | ||
211 | ) | ||
212 | parser.add_argument( | ||
227 | "--seed", type=int, default=None, help="A seed for reproducible training." | 213 | "--seed", type=int, default=None, help="A seed for reproducible training." |
228 | ) | 214 | ) |
229 | parser.add_argument( | 215 | parser.add_argument( |
@@ -476,7 +462,7 @@ def parse_args(): | |||
476 | parser.add_argument( | 462 | parser.add_argument( |
477 | "--sample_steps", | 463 | "--sample_steps", |
478 | type=int, | 464 | type=int, |
479 | default=10, | 465 | default=15, |
480 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 466 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
481 | ) | 467 | ) |
482 | parser.add_argument( | 468 | parser.add_argument( |
@@ -622,8 +608,6 @@ def main(): | |||
622 | elif args.mixed_precision == "bf16": | 608 | elif args.mixed_precision == "bf16": |
623 | weight_dtype = torch.bfloat16 | 609 | weight_dtype = torch.bfloat16 |
624 | 610 | ||
625 | patch_xformers(weight_dtype) | ||
626 | |||
627 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 611 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
628 | 612 | ||
629 | if args.seed is None: | 613 | if args.seed is None: |