-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py |  2
-rw-r--r--  train_lora.py                                       | 23
-rw-r--r--  train_ti.py                                         |  8
-rw-r--r--  training/functional.py                              |  6
-rw-r--r--  training/strategy/lora.py                           | 25
5 files changed, 37 insertions(+), 27 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f426de1..4505a2a 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -394,7 +394,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = "noise",
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
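
Note: the only change in this file is the default for image, which moves from the "noise" sentinel to None. A minimal, hypothetical call sketch (the pipe instance and prompt are placeholders; passing the literal "noise" explicitly should behave as the old default did):

    # Hypothetical calls against an instantiated VlpnStableDiffusion pipeline `pipe`.
    result = pipe(prompt="a corgi on a beach")                 # image=None: no explicit init image
    result = pipe(prompt="a corgi on a beach", image="noise")  # opt back in to the previous default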
diff --git a/train_lora.py b/train_lora.py
index 6e72376..e65e7be 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -292,6 +292,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=256,
+        help="LoRA rank.",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -420,10 +426,6 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
 
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
@@ -439,11 +441,18 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.lora_rank
         )
 
     unet.set_attn_processor(lora_attn_procs)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
     lora_layers = AttnProcsLayers(unet.attn_processors)
 
     if args.embeddings_dir is not None:
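
Note: the hunk above sits inside the standard diffusers recipe for attaching LoRA attention processors to a UNet. A condensed sketch of that surrounding loop is below; the names unet, args.lora_rank, LoRACrossAttnProcessor and AttnProcsLayers come from the diff, while the mid_block/up_blocks branches follow the usual diffusers example and may differ slightly from this repository:

    # Sketch of the usual diffusers LoRA setup loop this hunk is part of.
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention processors ("attn1") have no cross-attention conditioning.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            rank=args.lora_rank,
        )

    unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)  # only these parameters are trained

The diff also moves the vae slicing and xformers calls to after set_attn_processor, presumably so the memory-efficient attention setup applies to the freshly installed LoRA processors.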
diff --git a/train_ti.py b/train_ti.py
index b9d6e56..81938c8 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -476,13 +476,10 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
 
-    if args.num_vectors is None:
-        args.num_vectors = 1
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
-    if len(args.placeholder_tokens) != len(args.num_vectors):
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
     if args.sequential:
@@ -491,6 +488,9 @@ def parse_args():
 
         if len(args.placeholder_tokens) != len(args.train_data_template):
             raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+
+        if args.num_vectors is None:
+            args.num_vectors = [None] * len(args.placeholder_tokens)
     else:
         if isinstance(args.train_data_template, list):
             raise ValueError("--train_data_template can't be a list in simultaneous mode")
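
Note: the net effect of the two train_ti.py hunks is that --num_vectors may now be omitted entirely. A rough sketch of the resulting normalization, using a hypothetical helper (the real code keeps the None case in the sequential branch of parse_args):

    def normalize_num_vectors(num_vectors, placeholder_tokens):
        # Mirrors the parse_args() logic above; illustrative helper only.
        if isinstance(num_vectors, int):
            num_vectors = [num_vectors] * len(placeholder_tokens)
        if isinstance(num_vectors, list) and len(num_vectors) != len(placeholder_tokens):
            raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
        if num_vectors is None:
            num_vectors = [None] * len(placeholder_tokens)
        return num_vectors

    print(normalize_num_vectors(3, ["<a>", "<b>"]))     # [3, 3]
    print(normalize_num_vectors(None, ["<a>", "<b>"]))  # [None, None]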
diff --git a/training/functional.py b/training/functional.py
index 27a43c2..4565612 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,12 +231,16 @@ def add_placeholder_tokens(
     embeddings: ManagedCLIPTextEmbeddings,
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
+    num_vectors: Optional[Union[list[int], int]] = None,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
         for token in initializer_tokens
     ]
+
+    if num_vectors is None:
+        num_vectors = [len(ids) for ids in initializer_token_ids]
+
     placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
 
     embeddings.resize(len(tokenizer))
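
Note: with the new default, an omitted vector count falls back to however many tokens each initializer text encodes to. A small standalone sketch of the idea, assuming the stock CLIP tokenizer used by Stable Diffusion (exact token counts depend on the vocabulary):

    from transformers import CLIPTokenizer

    # Illustrative only: a multi-word initializer yields a multi-vector placeholder.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed base tokenizer
    initializer_tokens = ["dog", "red dress"]
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)  # e.g. one id for "dog", two for "red dress"
        for token in initializer_tokens
    ]
    num_vectors = [len(ids) for ids in initializer_token_ids]  # e.g. [1, 2]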
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ccec215..cab5e4c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -11,10 +11,7 @@ from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.loaders import AttnProcsLayers
 
-from slugify import slugify
-
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -41,16 +38,9 @@ def lora_strategy_callbacks(
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
@@ -83,20 +73,27 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
+        unet_.save_attn_procs(
+            checkpoint_output_dir / f"{step}_{postfix}",
+            safe_serialization=True
+        )
         del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        save_samples_(step=step)
+        unet_ = accelerator.unwrap_model(unet, False)
+        save_samples_(step=step, unet=unet_)
+        del unet_
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
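
Note: since on_checkpoint now writes the attention processors with safe_serialization=True, a saved checkpoint can be pulled back into a pipeline with the stock diffusers loader. A rough sketch, with placeholder paths and assuming a diffusers version whose load_attn_procs understands safetensors checkpoints:

    # Hypothetical reload of a LoRA checkpoint written by on_checkpoint above.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("path/to/base-model")   # placeholder base model
    pipe.unet.load_attn_procs("output/checkpoints/500_milestone")          # placeholder checkpoint dir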