-rw-r--r--  train_dreambooth.py        | 12
-rw-r--r--  train_lora.py              | 25
-rw-r--r--  train_ti.py                | 12
-rw-r--r--  training/strategy/lora.py  | 23
4 files changed, 35 insertions(+), 37 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5a4c47b..a29c507 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -442,6 +442,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -495,12 +501,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
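
Context for the hunks above: they do not add new logic, they move the existing weight_dtype selection from just before the trainer is built to just after the Accelerator is created, so the dtype is available earlier in main(). A minimal, self-contained sketch of that mapping; resolve_weight_dtype is a hypothetical helper name, the scripts keep the block inline:

import torch

def resolve_weight_dtype(mixed_precision):
    # Mirrors the relocated block: accelerate's --mixed_precision values
    # "fp16"/"bf16" select half/bfloat16 weights, anything else stays float32.
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32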
diff --git a/train_lora.py b/train_lora.py
index b273ae1..ab1753b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor
+from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -178,6 +178,11 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
         "--find_lr",
         action="store_true",
         help="Automatically find a learning rate (no training).",
@@ -402,6 +407,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -418,6 +429,12 @@ def main():
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     lora_attn_procs = {}
     for name in unet.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
@@ -467,12 +484,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     trainer = partial(
         train,
         accelerator=accelerator,
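
The train_lora.py changes visible above do three things: extend the import to include LoRACrossAttnProcessor, add a --gradient_checkpointing flag, and, with weight_dtype now resolved early, move the frozen unet and text_encoder onto the accelerator device in that dtype before the LoRA attention processors are collected. A hedged sketch of how the added lines fit together inside main(); args, accelerator, unet, text_encoder and weight_dtype are assumed to exist as in the script, and this is a condensation of the hunks rather than a verbatim copy:

# Optionally trade extra backward-pass compute for lower activation memory.
if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()

# The frozen base models are placed on the device in the (possibly reduced
# precision) weight dtype; the LoRA layers themselves are handed to
# accelerator.prepare() later (see training/strategy/lora.py below).
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)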
diff --git a/train_ti.py b/train_ti.py
index 56f9e97..2840def 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -513,6 +513,12 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
@@ -564,12 +570,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     checkpoint_output_dir = output_dir.joinpath("checkpoints")
 
     trainer = partial(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 92abaa6..bc10e58 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -89,20 +89,14 @@ def lora_strategy_callbacks(
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=torch.float32)
-        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
-        unet.to(dtype=orig_unet_dtype)
+
+        unet_ = accelerator.unwrap_model(unet)
+        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        orig_unet_dtype = unet.dtype
-        unet.to(dtype=weight_dtype)
         save_samples_(step=step)
-        unet.to(dtype=orig_unet_dtype)
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
@@ -126,16 +120,9 @@ def lora_prepare(
     lora_layers: AttnProcsLayers,
     **kwargs
 ):
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
         lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
-    unet.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
 
 
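
On the strategy side, the callbacks no longer round-trip the UNet through different dtypes: on_checkpoint unwraps the accelerate-prepared model and saves the LoRA attention processors directly, on_sample just calls save_samples_, and lora_prepare leaves device and dtype placement to the training script. A minimal hedged sketch of the resulting checkpoint callback; names are taken from the diff, and accelerator, unet and checkpoint_output_dir are assumed to come from the enclosing lora_strategy_callbacks scope:

import torch

@torch.no_grad()
def on_checkpoint(step, postfix):
    print(f"Saving checkpoint for step {step}...")
    # Unwrap the model prepared by accelerate, then save only the LoRA
    # attention processors; no temporary cast to float32 is needed anymore.
    unet_ = accelerator.unwrap_model(unet)
    unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
    del unet_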