 pipelines/stable_diffusion/vlpn_stable_diffusion.py |  2 +-
 train_lora.py                                       | 23 ++++++++++++++++-------
 train_ti.py                                         |  8 ++++----
 training/functional.py                              |  6 +++++-
 training/strategy/lora.py                           | 25 +++++++++++--------------
 5 files changed, 37 insertions(+), 27 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f426de1..4505a2a 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -394,7 +394,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = "noise",
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
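With the default flipped from "noise" to None, the pipeline presumably maps a missing image back to the pure-noise (txt2img) case somewhere inside __call__. A minimal sketch of that guard, assuming such handling exists outside this hunk (it is not shown in the diff):

    # Hypothetical guard near the top of __call__ (assumed, not in this diff):
    # a caller who passes nothing gets plain txt2img behavior.
    if image is None:
        image = "noise"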
diff --git a/train_lora.py b/train_lora.py
index 6e72376..e65e7be 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -13,7 +13,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor, LoRACrossAttnProcessor
+from diffusers.models.cross_attention import LoRACrossAttnProcessor

 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
@@ -292,6 +292,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=256,
+        help="LoRA rank.",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=1,
@@ -420,10 +426,6 @@ def main():
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)

-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.enable_xformers_memory_efficient_attention()
-
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)

@@ -439,11 +441,18 @@
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]

-        lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.lora_rank
         )

     unet.set_attn_processor(lora_attn_procs)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
     lora_layers = AttnProcsLayers(unet.attn_processors)

     if args.embeddings_dir is not None:
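For context, the loop these hunks edit follows the upstream diffusers text-to-image LoRA example. A sketch of its shape inside main() after the change; everything outside the visible context (the loop header and the mid/up-block branches) is reconstructed from that example, not from this diff:

from diffusers.models.cross_attention import LoRACrossAttnProcessor

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross-attention dim); attn2 attends to the text encoder
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
        rank=args.lora_rank,  # new --lora_rank flag, default 256
    )

unet.set_attn_processor(lora_attn_procs)
# xformers is enabled only after the LoRA processors are installed: calling
# unet.enable_xformers_memory_efficient_attention() should then swap in the
# xformers-aware processor variant, instead of hard-coding
# LoRAXFormersCrossAttnProcessor up front as the old code did.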
diff --git a/train_ti.py b/train_ti.py
index b9d6e56..81938c8 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -476,13 +476,10 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
         raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

-    if args.num_vectors is None:
-        args.num_vectors = 1
-
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)

-    if len(args.placeholder_tokens) != len(args.num_vectors):
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

     if args.sequential:
@@ -491,6 +488,9 @@

         if len(args.placeholder_tokens) != len(args.train_data_template):
             raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+
+        if args.num_vectors is None:
+            args.num_vectors = [None] * len(args.placeholder_tokens)
     else:
         if isinstance(args.train_data_template, list):
             raise ValueError("--train_data_template can't be a list in simultaneous mode")
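Taken together, these two hunks change how num_vectors is resolved: an int still fans out to every placeholder, but a missing value is no longer forced to 1; in sequential mode it becomes a per-token None that is resolved later from the initializer. A condensed, self-contained sketch of the resulting logic (normalize_num_vectors is a hypothetical name; the script does this inline in parse_args):

def normalize_num_vectors(num_vectors, placeholder_tokens):
    # an int applies uniformly to every placeholder token
    if isinstance(num_vectors, int):
        num_vectors = [num_vectors] * len(placeholder_tokens)
    # an explicit list must line up with the placeholders
    if isinstance(num_vectors, list) and len(placeholder_tokens) != len(num_vectors):
        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
    # omitted entirely: defer, letting add_placeholder_tokens pick a width per token
    if num_vectors is None:
        num_vectors = [None] * len(placeholder_tokens)
    return num_vectors

assert normalize_num_vectors(2, ["<a>", "<b>"]) == [2, 2]
assert normalize_num_vectors(None, ["<a>", "<b>"]) == [None, None]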
diff --git a/training/functional.py b/training/functional.py
index 27a43c2..4565612 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -231,12 +231,16 @@ def add_placeholder_tokens(
     embeddings: ManagedCLIPTextEmbeddings,
     placeholder_tokens: list[str],
     initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
+    num_vectors: Optional[Union[list[int], int]] = None,
 ):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
         for token in initializer_tokens
     ]
+
+    if num_vectors is None:
+        num_vectors = [len(ids) for ids in initializer_token_ids]
+
     placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)

     embeddings.resize(len(tokenizer))
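With the new default, omitting num_vectors sizes each placeholder to however many ids its initializer token encodes to. A toy illustration with hypothetical token ids:

# Hypothetical ids: a common initializer encodes to one id, a rarer one to two.
initializer_token_ids = [[2368], [4160, 539]]
num_vectors = [len(ids) for ids in initializer_token_ids]
assert num_vectors == [1, 2]  # one embedding vector per initializer id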
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ccec215..cab5e4c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -11,10 +11,7 @@ from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.loaders import AttnProcsLayers

-from slugify import slugify
-
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


@@ -41,16 +38,9 @@ def lora_strategy_callbacks(
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
@@ -83,20 +73,27 @@
         yield

     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)

     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")

         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
+        unet_.save_attn_procs(
+            checkpoint_output_dir / f"{step}_{postfix}",
+            safe_serialization=True
+        )
         del unet_

     @torch.no_grad()
     def on_sample(step):
-        save_samples_(step=step)
+        unet_ = accelerator.unwrap_model(unet, False)
+        save_samples_(step=step, unet=unet_)
+        del unet_
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()

     return TrainingCallbacks(
         on_prepare=on_prepare,
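With safe_serialization=True, save_attn_procs should write the LoRA attention weights as safetensors rather than a pickled .bin. A hedged usage sketch for loading such a checkpoint back onto a fresh UNet via the standard diffusers loader mixin (the model id and checkpoint path are illustrative, matching the f"{step}_{postfix}" naming above):

from diffusers import UNet2DConditionModel

# Base UNet to apply the trained LoRA processors to (illustrative model id).
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
# Reads the attention processors saved by on_checkpoint above.
unet.load_attn_procs("output/checkpoints/1000_end")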
