 infer.py               | 39 +-
 train_lora.py          | 51 +-
 training/functional.py | 40 +-
 3 files changed, 73 insertions(+), 57 deletions(-)
diff --git a/infer.py b/infer.py
index ed86ab1..93848d7 100644
--- a/infer.py
+++ b/infer.py
@@ -26,6 +26,8 @@ from diffusers import (
     DEISMultistepScheduler,
     UniPCMultistepScheduler
 )
+from peft import LoraConfig, LoraModel, set_peft_model_state_dict
+from safetensors.torch import load_file
 from transformers import CLIPTextModel
 
 from data.keywords import str_to_keywords, keywords_to_str
@@ -43,7 +45,7 @@ default_args = {
     "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
-    "lora_embeddings_dir": "embeddings_lora",
+    "lora_embedding": None,
     "output_dir": "output/inference",
     "config": None,
 }
@@ -99,7 +101,7 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--lora_embeddings_dir",
+        "--lora_embedding",
         type=str,
     )
     parser.add_argument(
@@ -236,6 +238,38 @@ def load_embeddings(pipeline, embeddings_dir):
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
 
+def load_lora(pipeline, path):
+    if path is None:
+        return
+
+    path = Path(path)
+
+    with open(path / "lora_config.json", "r") as f:
+        lora_config = json.load(f)
+
+    tensor_files = list(path.glob("*_end.safetensors"))
+
+    if len(tensor_files) == 0:
+        return
+
+    lora_checkpoint_sd = load_file(tensor_files[0])  # path.glob() already yields paths under path
+    unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+    text_encoder_lora_ds = {
+        k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+    }
+
+    unet_config = LoraConfig(**lora_config["peft_config"])
+    pipeline.unet = LoraModel(unet_config, pipeline.unet)
+    set_peft_model_state_dict(pipeline.unet, unet_lora_ds)
+
+    if "text_encoder_peft_config" in lora_config:
+        text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
+        pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
+        set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
+
+    return
+
+
 def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None):
     if scheduler == "plms":
         return PNDMScheduler.from_config(config)
@@ -441,6 +475,7 @@ def main():
     pipeline = create_pipeline(args.model, dtype)
 
     load_embeddings(pipeline, args.ti_embeddings_dir)
+    load_lora(pipeline, args.lora_embedding)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
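For reference, the inference-side pieces above fit together as follows. A minimal usage sketch, using only the functions shown in this diff; the directory name my_lora and the checkpoint filename are hypothetical examples chosen to satisfy the *_end.safetensors glob:

# Expected layout of the directory passed as --lora_embedding:
#
#   my_lora/
#       lora_config.json             (keys: "peft_config", optionally "text_encoder_peft_config")
#       lora_10000_end.safetensors   (merged unet + text encoder LoRA weights)
#
# load_lora() splits the merged state dict on the "text_encoder_" key prefix,
# wraps both models in peft LoraModel instances, and restores the weights via
# set_peft_model_state_dict().
pipeline = create_pipeline(args.model, dtype)
load_embeddings(pipeline, args.ti_embeddings_dir)
load_lora(pipeline, "my_lora")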
diff --git a/train_lora.py b/train_lora.py
index 73b3e19..1ca56d9 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import logging
-import itertools
 from pathlib import Path
 from functools import partial
 import math
@@ -247,9 +246,15 @@ def parse_args():
         help="Automatically find a learning rate (no training).",
     )
     parser.add_argument(
-        "--learning_rate",
+        "--learning_rate_unet",
         type=float,
-        default=2e-6,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--learning_rate_text",
+        type=float,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -548,13 +553,18 @@ def main():
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
+        args.learning_rate_unet = (
+            args.learning_rate_unet * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+        args.learning_rate_text = (
+            args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
-        args.learning_rate = 1e-6
+        args.learning_rate_unet = 1e-6
+        args.learning_rate_text = 1e-6
         args.lr_scheduler = "exponential_growth"
 
     if args.optimizer == 'adam8bit':
@@ -611,8 +621,8 @@ def main():
         )
 
         args.lr_scheduler = "adafactor"
-        args.lr_min_lr = args.learning_rate
-        args.learning_rate = None
+        args.lr_min_lr = args.learning_rate_unet
+        args.learning_rate_unet = None
     elif args.optimizer == 'dadam':
         try:
             import dadaptation
@@ -628,7 +638,8 @@ def main():
             d0=args.dadaptation_d0,
         )
 
-        args.learning_rate = 1.0
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     elif args.optimizer == 'dadan':
         try:
             import dadaptation
@@ -642,7 +653,8 @@ def main():
             d0=args.dadaptation_d0,
         )
 
-        args.learning_rate = 1.0
+        args.learning_rate_unet = 1.0
+        args.learning_rate_text = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
@@ -695,15 +707,16 @@ def main():
     sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
 
     optimizer = create_optimizer(
-        (
-            param
-            for param in itertools.chain(
-                unet.parameters(),
-                text_encoder.parameters(),
-            )
-            if param.requires_grad
-        ),
-        lr=args.learning_rate,
+        [
+            {
+                "params": unet.parameters(),
+                "lr": args.learning_rate_unet,
+            },
+            {
+                "params": text_encoder.parameters(),
+                "lr": args.learning_rate_text,
+            },
+        ]
     )
 
     lr_scheduler = get_scheduler(
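The optimizer hunk is the core of this change: instead of one learning rate over a single filtered parameter generator, the unet and text encoder now sit in separate parameter groups, each with its own rate. A self-contained sketch of the mechanism with toy modules and illustrative rates (any torch optimizer accepts this list-of-dicts form; dropping the old requires_grad filter is harmless because parameters whose .grad stays None are skipped at step()):

import torch
import torch.nn as nn

# Toy stand-ins for the unet and text encoder.
unet = nn.Linear(4, 4)
text_encoder = nn.Linear(4, 4)

# Each dict is a parameter group; its "lr" overrides the optimizer default.
optimizer = torch.optim.AdamW([
    {"params": unet.parameters(), "lr": 1e-4},
    {"params": text_encoder.parameters(), "lr": 5e-5},
])

for group in optimizer.param_groups:
    print(group["lr"])  # 0.0001, then 5e-05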
diff --git a/training/functional.py b/training/functional.py
index 06848cb..c30d1c0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -321,45 +321,13 @@ def loss_step(
     )
 
     if offset_noise_strength != 0:
-        solid_image = partial(
-            make_solid_image,
-            shape=images.shape[1:],
-            vae=vae,
+        offset_noise = torch.randn(
+            (latents.shape[0], latents.shape[1], 1, 1),
             dtype=latents.dtype,
             device=latents.device,
             generator=generator
-        )
-
-        white_cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
-        black_cache_key = f"img_black_{images.shape[2]}_{images.shape[3]}"
-
-        if white_cache_key not in cache:
-            img_white = solid_image(1)
-            cache[white_cache_key] = img_white
-        else:
-            img_white = cache[white_cache_key]
-
-        if black_cache_key not in cache:
-            img_black = solid_image(0)
-            cache[black_cache_key] = img_black
-        else:
-            img_black = cache[black_cache_key]
-
-        offset_strength = torch.rand(
-            (bsz, 1, 1, 1),
-            dtype=latents.dtype,
-            layout=latents.layout,
-            device=latents.device,
-            generator=generator
-        )
-        offset_strength = offset_noise_strength * (offset_strength * 2 - 1)
-        offset_images = torch.where(
-            offset_strength >= 0,
-            img_white.expand(noise.shape),
-            img_black.expand(noise.shape)
-        )
-        offset_strength = offset_strength.abs().expand(noise.shape)
-        noise = slerp(noise, offset_images, offset_strength, zdim=(-1, -2))
+        ).expand(noise.shape)
+        noise += offset_noise_strength * offset_noise
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
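In training/functional.py the bespoke solid-image slerp scheme is replaced with the standard offset-noise trick: one random scalar per (sample, channel), broadcast over the spatial dimensions, is added on top of the ordinary Gaussian noise so the model can learn global brightness shifts. A standalone sketch of the same computation, with illustrative shapes:

import torch

offset_noise_strength = 0.1
latents = torch.randn(2, 4, 64, 64)  # (batch, channels, height, width)
noise = torch.randn_like(latents)

# One offset per (sample, channel), expanded across height and width.
offset_noise = torch.randn(
    (latents.shape[0], latents.shape[1], 1, 1),
    dtype=latents.dtype,
    device=latents.device,
).expand(noise.shape)
noise = noise + offset_noise_strength * offset_noise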