diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-24 16:26:22 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-24 16:26:22 +0200 |
| commit | 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 (patch) | |
| tree | 6c1f2243475778bb5e9e1725bf3969a5442393d8 /train_dreambooth.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.tar.gz textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.tar.bz2 textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.zip | |
Fixes
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 30 |
1 files changed, 7 insertions, 23 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index c8f03ea..be4da1a 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -17,7 +17,6 @@ from accelerate.logging import get_logger | |||
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | 18 | ||
| 19 | # from diffusers.models.attention_processor import AttnProcessor | 19 | # from diffusers.models.attention_processor import AttnProcessor |
| 20 | from diffusers.utils.import_utils import is_xformers_available | ||
| 21 | import transformers | 20 | import transformers |
| 22 | 21 | ||
| 23 | import numpy as np | 22 | import numpy as np |
| @@ -48,25 +47,6 @@ hidet.torch.dynamo_config.use_tensor_core(True) | |||
| 48 | hidet.torch.dynamo_config.search_space(0) | 47 | hidet.torch.dynamo_config.search_space(0) |
| 49 | 48 | ||
| 50 | 49 | ||
| 51 | def patch_xformers(dtype): | ||
| 52 | if is_xformers_available(): | ||
| 53 | import xformers | ||
| 54 | import xformers.ops | ||
| 55 | |||
| 56 | orig_xformers_memory_efficient_attention = ( | ||
| 57 | xformers.ops.memory_efficient_attention | ||
| 58 | ) | ||
| 59 | |||
| 60 | def xformers_memory_efficient_attention( | ||
| 61 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
| 62 | ): | ||
| 63 | return orig_xformers_memory_efficient_attention( | ||
| 64 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
| 65 | ) | ||
| 66 | |||
| 67 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
| 68 | |||
| 69 | |||
| 70 | def parse_args(): | 50 | def parse_args(): |
| 71 | parser = argparse.ArgumentParser(description="Simple example of a training script.") | 51 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 72 | parser.add_argument( | 52 | parser.add_argument( |
| @@ -224,6 +204,12 @@ def parse_args(): | |||
| 224 | help="A collection to filter the dataset.", | 204 | help="A collection to filter the dataset.", |
| 225 | ) | 205 | ) |
| 226 | parser.add_argument( | 206 | parser.add_argument( |
| 207 | "--validation_prompts", | ||
| 208 | type=str, | ||
| 209 | nargs="*", | ||
| 210 | help="Prompts for additional validation images", | ||
| 211 | ) | ||
| 212 | parser.add_argument( | ||
| 227 | "--seed", type=int, default=None, help="A seed for reproducible training." | 213 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 228 | ) | 214 | ) |
| 229 | parser.add_argument( | 215 | parser.add_argument( |
| @@ -476,7 +462,7 @@ def parse_args(): | |||
| 476 | parser.add_argument( | 462 | parser.add_argument( |
| 477 | "--sample_steps", | 463 | "--sample_steps", |
| 478 | type=int, | 464 | type=int, |
| 479 | default=10, | 465 | default=15, |
| 480 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 466 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
| 481 | ) | 467 | ) |
| 482 | parser.add_argument( | 468 | parser.add_argument( |
| @@ -622,8 +608,6 @@ def main(): | |||
| 622 | elif args.mixed_precision == "bf16": | 608 | elif args.mixed_precision == "bf16": |
| 623 | weight_dtype = torch.bfloat16 | 609 | weight_dtype = torch.bfloat16 |
| 624 | 610 | ||
| 625 | patch_xformers(weight_dtype) | ||
| 626 | |||
| 627 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 611 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
| 628 | 612 | ||
| 629 | if args.seed is None: | 613 | if args.seed is None: |
