summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py30
1 files changed, 7 insertions, 23 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c8f03ea..be4da1a 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18 18
19# from diffusers.models.attention_processor import AttnProcessor 19# from diffusers.models.attention_processor import AttnProcessor
20from diffusers.utils.import_utils import is_xformers_available
21import transformers 20import transformers
22 21
23import numpy as np 22import numpy as np
@@ -48,25 +47,6 @@ hidet.torch.dynamo_config.use_tensor_core(True)
48hidet.torch.dynamo_config.search_space(0) 47hidet.torch.dynamo_config.search_space(0)
49 48
50 49
51def patch_xformers(dtype):
52 if is_xformers_available():
53 import xformers
54 import xformers.ops
55
56 orig_xformers_memory_efficient_attention = (
57 xformers.ops.memory_efficient_attention
58 )
59
60 def xformers_memory_efficient_attention(
61 query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
62 ):
63 return orig_xformers_memory_efficient_attention(
64 query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
65 )
66
67 xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
68
69
70def parse_args(): 50def parse_args():
71 parser = argparse.ArgumentParser(description="Simple example of a training script.") 51 parser = argparse.ArgumentParser(description="Simple example of a training script.")
72 parser.add_argument( 52 parser.add_argument(
@@ -224,6 +204,12 @@ def parse_args():
224 help="A collection to filter the dataset.", 204 help="A collection to filter the dataset.",
225 ) 205 )
226 parser.add_argument( 206 parser.add_argument(
207 "--validation_prompts",
208 type=str,
209 nargs="*",
210 help="Prompts for additional validation images",
211 )
212 parser.add_argument(
227 "--seed", type=int, default=None, help="A seed for reproducible training." 213 "--seed", type=int, default=None, help="A seed for reproducible training."
228 ) 214 )
229 parser.add_argument( 215 parser.add_argument(
@@ -476,7 +462,7 @@ def parse_args():
476 parser.add_argument( 462 parser.add_argument(
477 "--sample_steps", 463 "--sample_steps",
478 type=int, 464 type=int,
479 default=10, 465 default=15,
480 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 466 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
481 ) 467 )
482 parser.add_argument( 468 parser.add_argument(
@@ -622,8 +608,6 @@ def main():
622 elif args.mixed_precision == "bf16": 608 elif args.mixed_precision == "bf16":
623 weight_dtype = torch.bfloat16 609 weight_dtype = torch.bfloat16
624 610
625 patch_xformers(weight_dtype)
626
627 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) 611 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
628 612
629 if args.seed is None: 613 if args.seed is None: