From 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 16:26:22 +0200 Subject: Fixes --- train_dreambooth.py | 30 +++++++----------------------- 1 file changed, 7 insertions(+), 23 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index c8f03ea..be4da1a 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -17,7 +17,6 @@ from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed # from diffusers.models.attention_processor import AttnProcessor -from diffusers.utils.import_utils import is_xformers_available import transformers import numpy as np @@ -48,25 +47,6 @@ hidet.torch.dynamo_config.use_tensor_core(True) hidet.torch.dynamo_config.search_space(0) -def patch_xformers(dtype): - if is_xformers_available(): - import xformers - import xformers.ops - - orig_xformers_memory_efficient_attention = ( - xformers.ops.memory_efficient_attention - ) - - def xformers_memory_efficient_attention( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs - ): - return orig_xformers_memory_efficient_attention( - query.to(dtype), key.to(dtype), value.to(dtype), **kwargs - ) - - xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention - - def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -223,6 +203,12 @@ def parse_args(): nargs="*", help="A collection to filter the dataset.", ) + parser.add_argument( + "--validation_prompts", + type=str, + nargs="*", + help="Prompts for additional validation images", + ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) @@ -476,7 +462,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=10, + default=15, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -622,8 +608,6 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - patch_xformers(weight_dtype) - logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: -- cgit v1.2.3-54-g00ecf