summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- data/csv.py 10
-rw-r--r-- train_dreambooth.py 24
-rw-r--r-- train_ti.py 24
-rw-r--r-- training/lora.py 92
-rw-r--r-- training/lr.py 6
5 files changed, 113 insertions, 43 deletions
diff --git a/data/csv.py b/data/csv.py
index af36d9e..e901ab4 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -59,7 +59,7 @@ class CSVDataModule():
59 center_crop: bool = False, 59 center_crop: bool = False,
60 template_key: str = "template", 60 template_key: str = "template",
61 valid_set_size: Optional[int] = None, 61 valid_set_size: Optional[int] = None,
62 generator: Optional[torch.Generator] = None, 62 seed: Optional[int] = None,
63 filter: Optional[Callable[[CSVDataItem], bool]] = None, 63 filter: Optional[Callable[[CSVDataItem], bool]] = None,
64 collate_fn=None, 64 collate_fn=None,
65 num_workers: int = 0 65 num_workers: int = 0
@@ -84,7 +84,7 @@ class CSVDataModule():
84 self.template_key = template_key 84 self.template_key = template_key
85 self.interpolation = interpolation 85 self.interpolation = interpolation
86 self.valid_set_size = valid_set_size 86 self.valid_set_size = valid_set_size
87 self.generator = generator 87 self.seed = seed
88 self.filter = filter 88 self.filter = filter
89 self.collate_fn = collate_fn 89 self.collate_fn = collate_fn
90 self.num_workers = num_workers 90 self.num_workers = num_workers
@@ -155,7 +155,11 @@ class CSVDataModule():
155 valid_set_size = max(valid_set_size, 1) 155 valid_set_size = max(valid_set_size, 1)
156 train_set_size = num_images - valid_set_size 156 train_set_size = num_images - valid_set_size
157 157
158 data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator) 158 generator = torch.Generator(device="cpu")
159 if self.seed is not None:
160 generator = generator.manual_seed(self.seed)
161
162 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
159 163
160 self.data_train = self.pad_items(data_train, self.num_class_images) 164 self.data_train = self.pad_items(data_train, self.num_class_images)
161 self.data_val = self.pad_items(data_val) 165 self.data_val = self.pad_items(data_val)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index df8b54c..6d9bae8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -320,6 +320,12 @@ def parse_args():
320 help="Epsilon value for the Adam optimizer" 320 help="Epsilon value for the Adam optimizer"
321 ) 321 )
322 parser.add_argument( 322 parser.add_argument(
323 "--adam_amsgrad",
324 type=bool,
325 default=False,
326 help="Amsgrad value for the Adam optimizer"
327 )
328 parser.add_argument(
323 "--mixed_precision", 329 "--mixed_precision",
324 type=str, 330 type=str,
325 default="no", 331 default="no",
@@ -642,7 +648,7 @@ def main():
642 ) 648 )
643 649
644 if args.find_lr: 650 if args.find_lr:
645 args.learning_rate = 1e-4 651 args.learning_rate = 1e-6
646 652
647 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 653 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
648 if args.use_8bit_adam: 654 if args.use_8bit_adam:
@@ -674,6 +680,7 @@ def main():
674 betas=(args.adam_beta1, args.adam_beta2), 680 betas=(args.adam_beta1, args.adam_beta2),
675 weight_decay=args.adam_weight_decay, 681 weight_decay=args.adam_weight_decay,
676 eps=args.adam_epsilon, 682 eps=args.adam_epsilon,
683 amsgrad=args.adam_amsgrad,
677 ) 684 )
678 685
679 weight_dtype = torch.float32 686 weight_dtype = torch.float32
@@ -730,6 +737,7 @@ def main():
730 template_key=args.train_data_template, 737 template_key=args.train_data_template,
731 valid_set_size=args.valid_set_size, 738 valid_set_size=args.valid_set_size,
732 num_workers=args.dataloader_num_workers, 739 num_workers=args.dataloader_num_workers,
740 seed=args.seed,
733 filter=keyword_filter, 741 filter=keyword_filter,
734 collate_fn=collate_fn 742 collate_fn=collate_fn
735 ) 743 )
@@ -840,7 +848,7 @@ def main():
840 def on_eval(): 848 def on_eval():
841 tokenizer.eval() 849 tokenizer.eval()
842 850
843 def loop(batch): 851 def loop(batch, eval: bool = False):
844 # Convert images to latent space 852 # Convert images to latent space
845 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 853 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
846 latents = latents * 0.18215 854 latents = latents * 0.18215
@@ -849,8 +857,14 @@ def main():
849 noise = torch.randn_like(latents) 857 noise = torch.randn_like(latents)
850 bsz = latents.shape[0] 858 bsz = latents.shape[0]
851 # Sample a random timestep for each image 859 # Sample a random timestep for each image
852 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, 860 timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
853 (bsz,), device=latents.device) 861 timesteps = torch.randint(
862 0,
863 noise_scheduler.config.num_train_timesteps,
864 (bsz,),
865 generator=timesteps_gen,
866 device=latents.device,
867 )
854 timesteps = timesteps.long() 868 timesteps = timesteps.long()
855 869
856 # Add noise to the latents according to the noise magnitude at each timestep 870 # Add noise to the latents according to the noise magnitude at each timestep
@@ -1051,7 +1065,7 @@ def main():
1051 1065
1052 with torch.inference_mode(): 1066 with torch.inference_mode():
1053 for step, batch in enumerate(val_dataloader): 1067 for step, batch in enumerate(val_dataloader):
1054 loss, acc, bsz = loop(batch) 1068 loss, acc, bsz = loop(batch, True)
1055 1069
1056 loss = loss.detach_() 1070 loss = loss.detach_()
1057 acc = acc.detach_() 1071 acc = acc.detach_()
diff --git a/train_ti.py b/train_ti.py
index 1685dc4..5d6eafc 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -289,6 +289,12 @@ def parse_args():
289 help="Epsilon value for the Adam optimizer" 289 help="Epsilon value for the Adam optimizer"
290 ) 290 )
291 parser.add_argument( 291 parser.add_argument(
292 "--adam_amsgrad",
293 type=bool,
294 default=False,
295 help="Amsgrad value for the Adam optimizer"
296 )
297 parser.add_argument(
292 "--mixed_precision", 298 "--mixed_precision",
293 type=str, 299 type=str,
294 default="no", 300 default="no",
@@ -592,7 +598,7 @@ def main():
592 ) 598 )
593 599
594 if args.find_lr: 600 if args.find_lr:
595 args.learning_rate = 1e-4 601 args.learning_rate = 1e-6
596 602
597 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 603 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
598 if args.use_8bit_adam: 604 if args.use_8bit_adam:
@@ -612,6 +618,7 @@ def main():
612 betas=(args.adam_beta1, args.adam_beta2), 618 betas=(args.adam_beta1, args.adam_beta2),
613 weight_decay=args.adam_weight_decay, 619 weight_decay=args.adam_weight_decay,
614 eps=args.adam_epsilon, 620 eps=args.adam_epsilon,
621 amsgrad=args.adam_amsgrad,
615 ) 622 )
616 623
617 weight_dtype = torch.float32 624 weight_dtype = torch.float32
@@ -673,6 +680,7 @@ def main():
673 template_key=args.train_data_template, 680 template_key=args.train_data_template,
674 valid_set_size=args.valid_set_size, 681 valid_set_size=args.valid_set_size,
675 num_workers=args.dataloader_num_workers, 682 num_workers=args.dataloader_num_workers,
683 seed=args.seed,
676 filter=keyword_filter, 684 filter=keyword_filter,
677 collate_fn=collate_fn 685 collate_fn=collate_fn
678 ) 686 )
@@ -791,7 +799,7 @@ def main():
791 def on_eval(): 799 def on_eval():
792 tokenizer.eval() 800 tokenizer.eval()
793 801
794 def loop(batch): 802 def loop(batch, eval: bool = False):
795 # Convert images to latent space 803 # Convert images to latent space
796 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() 804 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
797 latents = latents * 0.18215 805 latents = latents * 0.18215
@@ -800,8 +808,14 @@ def main():
800 noise = torch.randn_like(latents) 808 noise = torch.randn_like(latents)
801 bsz = latents.shape[0] 809 bsz = latents.shape[0]
802 # Sample a random timestep for each image 810 # Sample a random timestep for each image
803 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, 811 timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
804 (bsz,), device=latents.device) 812 timesteps = torch.randint(
813 0,
814 noise_scheduler.config.num_train_timesteps,
815 (bsz,),
816 generator=timesteps_gen,
817 device=latents.device,
818 )
805 timesteps = timesteps.long() 819 timesteps = timesteps.long()
806 820
807 # Add noise to the latents according to the noise magnitude at each timestep 821 # Add noise to the latents according to the noise magnitude at each timestep
@@ -984,7 +998,7 @@ def main():
984 998
985 with torch.inference_mode(): 999 with torch.inference_mode():
986 for step, batch in enumerate(val_dataloader): 1000 for step, batch in enumerate(val_dataloader):
987 loss, acc, bsz = loop(batch) 1001 loss, acc, bsz = loop(batch, True)
988 1002
989 loss = loss.detach_() 1003 loss = loss.detach_()
990 acc = acc.detach_() 1004 acc = acc.detach_()
diff --git a/training/lora.py b/training/lora.py
index e1c0971..3857d78 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,3 +1,4 @@
1import torch
1import torch.nn as nn 2import torch.nn as nn
2 3
3from diffusers import ModelMixin, ConfigMixin 4from diffusers import ModelMixin, ConfigMixin
@@ -13,56 +14,93 @@ else:
13 xformers = None 14 xformers = None
14 15
15 16
16class LoraAttnProcessor(ModelMixin, ConfigMixin): 17class LoRALinearLayer(nn.Module):
17 @register_to_config 18 def __init__(self, in_features, out_features, rank=4):
18 def __init__(
19 self,
20 cross_attention_dim,
21 inner_dim,
22 r: int = 4
23 ):
24 super().__init__() 19 super().__init__()
25 20
26 if r > min(cross_attention_dim, inner_dim): 21 if rank > min(in_features, out_features):
27 raise ValueError( 22 raise ValueError(
28 f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}" 23 f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
29 ) 24 )
30 25
31 self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False) 26 self.lora_down = nn.Linear(in_features, rank, bias=False)
32 self.lora_k_up = nn.Linear(r, inner_dim, bias=False) 27 self.lora_up = nn.Linear(rank, out_features, bias=False)
28 self.scale = 1.0
33 29
34 self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False) 30 nn.init.normal_(self.lora_down.weight, std=1 / rank)
35 self.lora_v_up = nn.Linear(r, inner_dim, bias=False) 31 nn.init.zeros_(self.lora_up.weight)
36 32
37 self.scale = 1.0 33 def forward(self, hidden_states):
34 down_hidden_states = self.lora_down(hidden_states)
35 up_hidden_states = self.lora_up(down_hidden_states)
38 36
39 nn.init.normal_(self.lora_k_down.weight, std=1 / r**2) 37 return up_hidden_states
40 nn.init.zeros_(self.lora_k_up.weight)
41 38
42 nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
43 nn.init.zeros_(self.lora_v_up.weight)
44 39
45 def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 40class LoRACrossAttnProcessor(nn.Module):
46 batch_size, sequence_length, _ = hidden_states.shape 41 def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
42 super().__init__()
47 43
44 self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
45 self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
46 self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
47 self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
48
49 def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
50 batch_size, sequence_length, _ = hidden_states.shape
48 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length) 51 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
49 52
50 query = attn.to_q(hidden_states) 53 query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
54 query = attn.head_to_batch_dim(query)
51 55
52 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 56 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
53 key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
54 value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale
55 57
58 key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
59 value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
60
61 key = attn.head_to_batch_dim(key)
62 value = attn.head_to_batch_dim(value)
63
64 attention_probs = attn.get_attention_scores(query, key, attention_mask)
65 hidden_states = torch.bmm(attention_probs, value)
66 hidden_states = attn.batch_to_head_dim(hidden_states)
67
68 # linear proj
69 hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
70 # dropout
71 hidden_states = attn.to_out[1](hidden_states)
72
73 return hidden_states
74
75
76class LoRAXFormersCrossAttnProcessor(nn.Module):
77 def __init__(self, hidden_size, cross_attention_dim, rank=4):
78 super().__init__()
79
80 self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
81 self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
82 self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
83 self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
84
85 def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
86 batch_size, sequence_length, _ = hidden_states.shape
87 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
88
89 query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
56 query = attn.head_to_batch_dim(query).contiguous() 90 query = attn.head_to_batch_dim(query).contiguous()
91
92 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
93
94 key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
95 value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
96
57 key = attn.head_to_batch_dim(key).contiguous() 97 key = attn.head_to_batch_dim(key).contiguous()
58 value = attn.head_to_batch_dim(value).contiguous() 98 value = attn.head_to_batch_dim(value).contiguous()
59 99
60 hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 100 hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
61 hidden_states = hidden_states.to(query.dtype)
62 hidden_states = attn.batch_to_head_dim(hidden_states)
63 101
64 # linear proj 102 # linear proj
65 hidden_states = attn.to_out[0](hidden_states) 103 hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
66 # dropout 104 # dropout
67 hidden_states = attn.to_out[1](hidden_states) 105 hidden_states = attn.to_out[1](hidden_states)
68 106
diff --git a/training/lr.py b/training/lr.py
index 37588b6..a3144ba 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,6 @@
1import math 1import math
2import copy 2import copy
3from typing import Callable 3from typing import Callable, Any, Tuple, Union
4from functools import partial 4from functools import partial
5 5
6import matplotlib.pyplot as plt 6import matplotlib.pyplot as plt
@@ -24,7 +24,7 @@ class LRFinder():
24 optimizer, 24 optimizer,
25 train_dataloader, 25 train_dataloader,
26 val_dataloader, 26 val_dataloader,
27 loss_fn, 27 loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]],
28 on_train: Callable[[], None] = noop, 28 on_train: Callable[[], None] = noop,
29 on_eval: Callable[[], None] = noop 29 on_eval: Callable[[], None] = noop
30 ): 30 ):
@@ -108,7 +108,7 @@ class LRFinder():
108 if step >= num_val_batches: 108 if step >= num_val_batches:
109 break 109 break
110 110
111 loss, acc, bsz = self.loss_fn(batch) 111 loss, acc, bsz = self.loss_fn(batch, True)
112 avg_loss.update(loss.detach_(), bsz) 112 avg_loss.update(loss.detach_(), bsz)
113 avg_acc.update(acc.detach_(), bsz) 113 avg_acc.update(acc.detach_(), bsz)
114 114