diff options
| -rw-r--r-- | infer.py | 39 | ||||
| -rw-r--r-- | train_lora.py | 51 | ||||
| -rw-r--r-- | training/functional.py | 40 |
3 files changed, 73 insertions, 57 deletions
| @@ -26,6 +26,8 @@ from diffusers import ( | |||
| 26 | DEISMultistepScheduler, | 26 | DEISMultistepScheduler, |
| 27 | UniPCMultistepScheduler | 27 | UniPCMultistepScheduler |
| 28 | ) | 28 | ) |
| 29 | from peft import LoraConfig, LoraModel, set_peft_model_state_dict | ||
| 30 | from safetensors.torch import load_file | ||
| 29 | from transformers import CLIPTextModel | 31 | from transformers import CLIPTextModel |
| 30 | 32 | ||
| 31 | from data.keywords import str_to_keywords, keywords_to_str | 33 | from data.keywords import str_to_keywords, keywords_to_str |
| @@ -43,7 +45,7 @@ default_args = { | |||
| 43 | "model": "stabilityai/stable-diffusion-2-1", | 45 | "model": "stabilityai/stable-diffusion-2-1", |
| 44 | "precision": "fp32", | 46 | "precision": "fp32", |
| 45 | "ti_embeddings_dir": "embeddings_ti", | 47 | "ti_embeddings_dir": "embeddings_ti", |
| 46 | "lora_embeddings_dir": "embeddings_lora", | 48 | "lora_embedding": None, |
| 47 | "output_dir": "output/inference", | 49 | "output_dir": "output/inference", |
| 48 | "config": None, | 50 | "config": None, |
| 49 | } | 51 | } |
| @@ -99,7 +101,7 @@ def create_args_parser(): | |||
| 99 | type=str, | 101 | type=str, |
| 100 | ) | 102 | ) |
| 101 | parser.add_argument( | 103 | parser.add_argument( |
| 102 | "--lora_embeddings_dir", | 104 | "--lora_embedding", |
| 103 | type=str, | 105 | type=str, |
| 104 | ) | 106 | ) |
| 105 | parser.add_argument( | 107 | parser.add_argument( |
| @@ -236,6 +238,38 @@ def load_embeddings(pipeline, embeddings_dir): | |||
| 236 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 238 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 237 | 239 | ||
| 238 | 240 | ||
| 241 | def load_lora(pipeline, path): | ||
| 242 | if path is None: | ||
| 243 | return | ||
| 244 | |||
| 245 | path = Path(path) | ||
| 246 | |||
| 247 | with open(path / "lora_config.json", "r") as f: | ||
| 248 | lora_config = json.load(f) | ||
| 249 | |||
| 250 | tensor_files = list(path.glob("*_end.safetensors")) | ||
| 251 | |||
| 252 | if len(tensor_files) == 0: | ||
| 253 | return | ||
| 254 | |||
| 255 | lora_checkpoint_sd = load_file(path / tensor_files[0]) | ||
| 256 | unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} | ||
| 257 | text_encoder_lora_ds = { | ||
| 258 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k | ||
| 259 | } | ||
| 260 | |||
| 261 | unet_config = LoraConfig(**lora_config["peft_config"]) | ||
| 262 | pipeline.unet = LoraModel(unet_config, pipeline.unet) | ||
| 263 | set_peft_model_state_dict(pipeline.unet, unet_lora_ds) | ||
| 264 | |||
| 265 | if "text_encoder_peft_config" in lora_config: | ||
| 266 | text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"]) | ||
| 267 | pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder) | ||
| 268 | set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds) | ||
| 269 | |||
| 270 | return | ||
| 271 | |||
| 272 | |||
| 239 | def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None): | 273 | def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None): |
| 240 | if scheduler == "plms": | 274 | if scheduler == "plms": |
| 241 | return PNDMScheduler.from_config(config) | 275 | return PNDMScheduler.from_config(config) |
| @@ -441,6 +475,7 @@ def main(): | |||
| 441 | pipeline = create_pipeline(args.model, dtype) | 475 | pipeline = create_pipeline(args.model, dtype) |
| 442 | 476 | ||
| 443 | load_embeddings(pipeline, args.ti_embeddings_dir) | 477 | load_embeddings(pipeline, args.ti_embeddings_dir) |
| 478 | load_lora(pipeline, args.lora_embedding) | ||
| 444 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 479 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
| 445 | 480 | ||
| 446 | cmd_parser = create_cmd_parser() | 481 | cmd_parser = create_cmd_parser() |
diff --git a/train_lora.py b/train_lora.py index 73b3e19..1ca56d9 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -1,7 +1,6 @@ | |||
| 1 | import argparse | 1 | import argparse |
| 2 | import datetime | 2 | import datetime |
| 3 | import logging | 3 | import logging |
| 4 | import itertools | ||
| 5 | from pathlib import Path | 4 | from pathlib import Path |
| 6 | from functools import partial | 5 | from functools import partial |
| 7 | import math | 6 | import math |
| @@ -247,9 +246,15 @@ def parse_args(): | |||
| 247 | help="Automatically find a learning rate (no training).", | 246 | help="Automatically find a learning rate (no training).", |
| 248 | ) | 247 | ) |
| 249 | parser.add_argument( | 248 | parser.add_argument( |
| 250 | "--learning_rate", | 249 | "--learning_rate_unet", |
| 251 | type=float, | 250 | type=float, |
| 252 | default=2e-6, | 251 | default=1e-4, |
| 252 | help="Initial learning rate (after the potential warmup period) to use.", | ||
| 253 | ) | ||
| 254 | parser.add_argument( | ||
| 255 | "--learning_rate_text", | ||
| 256 | type=float, | ||
| 257 | default=5e-5, | ||
| 253 | help="Initial learning rate (after the potential warmup period) to use.", | 258 | help="Initial learning rate (after the potential warmup period) to use.", |
| 254 | ) | 259 | ) |
| 255 | parser.add_argument( | 260 | parser.add_argument( |
| @@ -548,13 +553,18 @@ def main(): | |||
| 548 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 553 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 549 | 554 | ||
| 550 | if args.scale_lr: | 555 | if args.scale_lr: |
| 551 | args.learning_rate = ( | 556 | args.learning_rate_unet = ( |
| 552 | args.learning_rate * args.gradient_accumulation_steps * | 557 | args.learning_rate_unet * args.gradient_accumulation_steps * |
| 558 | args.train_batch_size * accelerator.num_processes | ||
| 559 | ) | ||
| 560 | args.learning_rate_text = ( | ||
| 561 | args.learning_rate_text * args.gradient_accumulation_steps * | ||
| 553 | args.train_batch_size * accelerator.num_processes | 562 | args.train_batch_size * accelerator.num_processes |
| 554 | ) | 563 | ) |
| 555 | 564 | ||
| 556 | if args.find_lr: | 565 | if args.find_lr: |
| 557 | args.learning_rate = 1e-6 | 566 | args.learning_rate_unet = 1e-6 |
| 567 | args.learning_rate_text = 1e-6 | ||
| 558 | args.lr_scheduler = "exponential_growth" | 568 | args.lr_scheduler = "exponential_growth" |
| 559 | 569 | ||
| 560 | if args.optimizer == 'adam8bit': | 570 | if args.optimizer == 'adam8bit': |
| @@ -611,8 +621,8 @@ def main(): | |||
| 611 | ) | 621 | ) |
| 612 | 622 | ||
| 613 | args.lr_scheduler = "adafactor" | 623 | args.lr_scheduler = "adafactor" |
| 614 | args.lr_min_lr = args.learning_rate | 624 | args.lr_min_lr = args.learning_rate_unet |
| 615 | args.learning_rate = None | 625 | args.learning_rate_unet = None |
| 616 | elif args.optimizer == 'dadam': | 626 | elif args.optimizer == 'dadam': |
| 617 | try: | 627 | try: |
| 618 | import dadaptation | 628 | import dadaptation |
| @@ -628,7 +638,8 @@ def main(): | |||
| 628 | d0=args.dadaptation_d0, | 638 | d0=args.dadaptation_d0, |
| 629 | ) | 639 | ) |
| 630 | 640 | ||
| 631 | args.learning_rate = 1.0 | 641 | args.learning_rate_unet = 1.0 |
| 642 | args.learning_rate_text = 1.0 | ||
| 632 | elif args.optimizer == 'dadan': | 643 | elif args.optimizer == 'dadan': |
| 633 | try: | 644 | try: |
| 634 | import dadaptation | 645 | import dadaptation |
| @@ -642,7 +653,8 @@ def main(): | |||
| 642 | d0=args.dadaptation_d0, | 653 | d0=args.dadaptation_d0, |
| 643 | ) | 654 | ) |
| 644 | 655 | ||
| 645 | args.learning_rate = 1.0 | 656 | args.learning_rate_unet = 1.0 |
| 657 | args.learning_rate_text = 1.0 | ||
| 646 | else: | 658 | else: |
| 647 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 659 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
| 648 | 660 | ||
| @@ -695,15 +707,16 @@ def main(): | |||
| 695 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 707 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) |
| 696 | 708 | ||
| 697 | optimizer = create_optimizer( | 709 | optimizer = create_optimizer( |
| 698 | ( | 710 | [ |
| 699 | param | 711 | { |
| 700 | for param in itertools.chain( | 712 | "params": unet.parameters(), |
| 701 | unet.parameters(), | 713 | "lr": args.learning_rate_unet, |
| 702 | text_encoder.parameters(), | 714 | }, |
| 703 | ) | 715 | { |
| 704 | if param.requires_grad | 716 | "params": text_encoder.parameters(), |
| 705 | ), | 717 | "lr": args.learning_rate_text, |
| 706 | lr=args.learning_rate, | 718 | }, |
| 719 | ] | ||
| 707 | ) | 720 | ) |
| 708 | 721 | ||
| 709 | lr_scheduler = get_scheduler( | 722 | lr_scheduler = get_scheduler( |
diff --git a/training/functional.py b/training/functional.py index 06848cb..c30d1c0 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -321,45 +321,13 @@ def loss_step( | |||
| 321 | ) | 321 | ) |
| 322 | 322 | ||
| 323 | if offset_noise_strength != 0: | 323 | if offset_noise_strength != 0: |
| 324 | solid_image = partial( | 324 | offset_noise = torch.randn( |
| 325 | make_solid_image, | 325 | (latents.shape[0], latents.shape[1], 1, 1), |
| 326 | shape=images.shape[1:], | ||
| 327 | vae=vae, | ||
| 328 | dtype=latents.dtype, | 326 | dtype=latents.dtype, |
| 329 | device=latents.device, | 327 | device=latents.device, |
| 330 | generator=generator | 328 | generator=generator |
| 331 | ) | 329 | ).expand(noise.shape) |
| 332 | 330 | noise += offset_noise_strength * offset_noise | |
| 333 | white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}" | ||
| 334 | black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}" | ||
| 335 | |||
| 336 | if white_cache_key not in cache: | ||
| 337 | img_white = solid_image(1) | ||
| 338 | cache[white_cache_key] = img_white | ||
| 339 | else: | ||
| 340 | img_white = cache[white_cache_key] | ||
| 341 | |||
| 342 | if black_cache_key not in cache: | ||
| 343 | img_black = solid_image(0) | ||
| 344 | cache[black_cache_key] = img_black | ||
| 345 | else: | ||
| 346 | img_black = cache[black_cache_key] | ||
| 347 | |||
| 348 | offset_strength = torch.rand( | ||
| 349 | (bsz, 1, 1, 1), | ||
| 350 | dtype=latents.dtype, | ||
| 351 | layout=latents.layout, | ||
| 352 | device=latents.device, | ||
| 353 | generator=generator | ||
| 354 | ) | ||
| 355 | offset_strength = offset_noise_strength * (offset_strength * 2 - 1) | ||
| 356 | offset_images = torch.where( | ||
| 357 | offset_strength >= 0, | ||
| 358 | img_white.expand(noise.shape), | ||
| 359 | img_black.expand(noise.shape) | ||
| 360 | ) | ||
| 361 | offset_strength = offset_strength.abs().expand(noise.shape) | ||
| 362 | noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2)) | ||
| 363 | 331 | ||
| 364 | # Sample a random timestep for each image | 332 | # Sample a random timestep for each image |
| 365 | timesteps = torch.randint( | 333 | timesteps = torch.randint( |
