 train_lora.py    | 159 +++++---------------------
 training/lora.py |  68 ++++++++--
 2 files changed, 89 insertions(+), 138 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index ffc1d10..34e1008 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -24,8 +23,9 @@ from slugify import slugify
 from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
-from training.lora import LoraAttention
+from training.lora import LoraAttnProcessor
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -109,7 +109,7 @@ def parse_args():
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="output/dreambooth",
+        default="output/lora",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
     parser.add_argument(
@@ -176,7 +176,7 @@ def parse_args():
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
     parser.add_argument(
-        "--learning_rate_unet",
+        "--learning_rate",
         type=float,
         default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
@@ -348,76 +348,45 @@ def parse_args():
     return args
 
 
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
         accelerator,
         vae,
         unet,
-        unet_lora,
         tokenizer,
         text_encoder,
+        unet_lora,
         scheduler,
-        output_dir: Path,
         instance_identifier,
         placeholder_token,
         placeholder_token_id,
+        output_dir: Path,
         sample_image_size,
         sample_batches,
         sample_batch_size,
         seed
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
-        self.unet_lora = unet_lora
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.unet_lora = unet_lora
         self.scheduler = scheduler
-        self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def save_model(self):
@@ -433,83 +402,18 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
+        # Save a sample image
         pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
+            text_encoder=self.text_encoder,
             vae=self.vae,
-            unet=unet,
+            unet=self.unet,
             tokenizer=self.tokenizer,
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
-
-        del unet
-        del text_encoder
         del pipeline
         del generator
         del stable_latents
@@ -558,7 +462,11 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    unet_lora = LoraAttention()
+    unet_lora = LoraAttnProcessor(
+        cross_attention_dim=unet.cross_attention_dim,
+        inner_dim=unet.in_channels,
+        r=4,
+    )
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -618,8 +526,8 @@ def main():
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
-        args.learning_rate_unet = (
-            args.learning_rate_unet * args.gradient_accumulation_steps *
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
 
@@ -639,7 +547,7 @@ def main():
         [
             {
                 'params': unet_lora.parameters(),
-                'lr': args.learning_rate_unet,
+                'lr': args.learning_rate,
             },
         ],
         betas=(args.adam_beta1, args.adam_beta2),
@@ -801,7 +709,7 @@ def main():
     config = vars(args).copy()
     config["initializer_token"] = " ".join(config["initializer_token"])
     config["placeholder_token"] = " ".join(config["placeholder_token"])
-    accelerator.init_trackers("dreambooth", config=config)
+    accelerator.init_trackers("lora", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -832,6 +740,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
+        unet_lora=unet_lora,
         output_dir=basepath,
         instance_identifier=instance_identifier,
         placeholder_token=args.placeholder_token,
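
Note: the hunks above hand only unet_lora.parameters() to the optimizer, so the UNet's own weights stay frozen and just the adapter is trained. The toy sketch below illustrates the LoRA principle this relies on: a frozen projection plus a trainable rank-r update whose up-projection starts at zero, so training begins from the unmodified model. It is illustrative code only, with arbitrary example dimensions, and is not part of this commit.

    # Toy LoRA sketch (illustrative only, arbitrary dims): frozen W plus a
    # trainable rank-r update B(A(x)) * scale, with B zero-initialized so the
    # adapted output equals the frozen output at step 0.
    import torch
    import torch.nn as nn

    d_in, d_out, r = 768, 320, 4

    frozen = nn.Linear(d_in, d_out, bias=False)
    frozen.requires_grad_(False)

    lora_down = nn.Linear(d_in, r, bias=False)   # A
    lora_up = nn.Linear(r, d_out, bias=False)    # B
    nn.init.normal_(lora_down.weight, std=1 / r**2)
    nn.init.zeros_(lora_up.weight)
    scale = 1.0

    x = torch.randn(2, 77, d_in)
    adapted = frozen(x) + lora_up(lora_down(x)) * scale
    assert torch.equal(adapted, frozen(x))  # no change at initialization

    # As in the optimizer hunk above, only the adapter weights are trained.
    optimizer = torch.optim.AdamW(
        [{"params": list(lora_down.parameters()) + list(lora_up.parameters()), "lr": 2e-6}]
    )
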
diff --git a/training/lora.py b/training/lora.py
index d8dc147..e1c0971 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,27 +1,69 @@
 import torch.nn as nn
-from diffusers import ModelMixin, ConfigMixin, XFormersCrossAttnProcessor, register_to_config
 
+from diffusers import ModelMixin, ConfigMixin
+from diffusers.configuration_utils import register_to_config
+from diffusers.models.cross_attention import CrossAttention
+from diffusers.utils.import_utils import is_xformers_available
 
-class LoraAttention(ModelMixin, ConfigMixin):
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+class LoraAttnProcessor(ModelMixin, ConfigMixin):
     @register_to_config
-    def __init__(self, in_features, out_features, r=4):
+    def __init__(
+        self,
+        cross_attention_dim,
+        inner_dim,
+        r: int = 4
+    ):
         super().__init__()
 
-        if r > min(in_features, out_features):
+        if r > min(cross_attention_dim, inner_dim):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
             )
 
-        self.lora_down = nn.Linear(in_features, r, bias=False)
-        self.lora_up = nn.Linear(r, out_features, bias=False)
+        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+
+        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
+        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+
         self.scale = 1.0
 
-        self.processor = XFormersCrossAttnProcessor()
+        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_k_up.weight)
+
+        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
+        nn.init.zeros_(self.lora_v_up.weight)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
+        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
 
-        nn.init.normal_(self.lora_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_up.weight)
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
 
-    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
-        hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
-        hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
         return hidden_states
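
Note: the rewritten processor above attaches rank-r adapters to the key and value projections only; the frozen attn.to_k / attn.to_v outputs are corrected by lora_*_up(lora_*_down(encoder_hidden_states)) * scale, and the zero-initialized up projections keep that correction at exactly zero before training starts. A quick sanity check along those lines, assuming training.lora imports cleanly in your environment (xformers is only needed for the full __call__, not for this check), with example dimensions rather than values taken from the training script:

    # Instantiate the processor from this commit and confirm the LoRA branches
    # contribute nothing at initialization.
    import torch
    from training.lora import LoraAttnProcessor

    proc = LoraAttnProcessor(cross_attention_dim=768, inner_dim=320, r=4)
    encoder_hidden_states = torch.randn(2, 77, 768)

    k_delta = proc.lora_k_up(proc.lora_k_down(encoder_hidden_states)) * proc.scale
    v_delta = proc.lora_v_up(proc.lora_v_down(encoder_hidden_states)) * proc.scale
    assert torch.count_nonzero(k_delta) == 0 and torch.count_nonzero(v_delta) == 0

    # Trainable footprint: 2 * (768*4 + 4*320) = 8,704 adapter weights for K and V
    # combined, versus 2 * (768*320) = 491,520 for the full projections.
    print(sum(p.numel() for p in proc.parameters()))  # -> 8704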